Numerics/Statistics/itkKdTree.h
Go to the documentation of this file.
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkKdTree_h
00018 #define __itkKdTree_h
00019
00020 #include <queue>
00021 #include <vector>
00022
00023 #include "itkMacro.h"
00024 #include "itkPoint.h"
00025 #include "itkSize.h"
00026 #include "itkObject.h"
00027 #include "itkNumericTraits.h"
00028 #include "itkArray.h"
00029
00030 #include "itkSample.h"
00031 #include "itkSubsample.h"
00032
00033 #include "itkEuclideanDistance.h"
00034
00035 namespace itk {
00036 namespace Statistics {
00037
00062 template< class TSample >
00063 struct KdTreeNode
00064 {
00066 typedef KdTreeNode< TSample> Self;
00067
00069 typedef typename TSample::MeasurementType MeasurementType;
00070
00072 typedef Array< double > CentroidType;
00073
00076 typedef typename TSample::InstanceIdentifier InstanceIdentifier;
00077
00080 virtual bool IsTerminal() const = 0;
00081
00087 virtual void GetParameters(unsigned int &partitionDimension,
00088 MeasurementType &partitionValue) const = 0;
00089
00091 virtual Self* Left() = 0;
00092 virtual const Self* Left() const = 0;
00094
00096 virtual Self* Right() = 0;
00097 virtual const Self* Right() const = 0;
00099
00102 virtual unsigned int Size() const = 0;
00103
00105 virtual void GetWeightedCentroid(CentroidType ¢roid) = 0;
00106
00108 virtual void GetCentroid(CentroidType ¢roid) = 0;
00109
00111 virtual InstanceIdentifier GetInstanceIdentifier(size_t index) const = 0;
00112
00114 virtual void AddInstanceIdentifier(InstanceIdentifier id) = 0;
00115
00117 virtual ~KdTreeNode() {};
00118 };
00119
00131 template< class TSample >
00132 struct KdTreeNonterminalNode: public KdTreeNode< TSample >
00133 {
00134 typedef KdTreeNode< TSample > Superclass;
00135 typedef typename Superclass::MeasurementType MeasurementType;
00136 typedef typename Superclass::CentroidType CentroidType;
00137 typedef typename Superclass::InstanceIdentifier InstanceIdentifier;
00138
00139 KdTreeNonterminalNode(unsigned int partitionDimension,
00140 MeasurementType partitionValue,
00141 Superclass* left,
00142 Superclass* right);
00143
00144 virtual ~KdTreeNonterminalNode() {}
00145
00146 virtual bool IsTerminal() const
00147 { return false; }
00148
00149 void GetParameters(unsigned int &partitionDimension,
00150 MeasurementType &partitionValue) const;
00151
00152 Superclass* Left()
00153 { return m_Left; }
00154
00155 Superclass* Right()
00156 { return m_Right; }
00157
00158 const Superclass* Left() const
00159 { return m_Left; }
00160
00161 const Superclass* Right() const
00162 { return m_Right; }
00163
00164 unsigned int Size() const
00165 { return 0; }
00166
00167 void GetWeightedCentroid( CentroidType & )
00168 {}
00169
00170 void GetCentroid( CentroidType & )
00171 {}
00172
00173
00174
00175
00176 InstanceIdentifier GetInstanceIdentifier(size_t) const
00177 { return this->m_InstanceIdentifier; }
00178
00179 void AddInstanceIdentifier(InstanceIdentifier valueId)
00180 { this->m_InstanceIdentifier = valueId; }
00181
00182 private:
00183 unsigned int m_PartitionDimension;
00184 MeasurementType m_PartitionValue;
00185 InstanceIdentifier m_InstanceIdentifier;
00186 Superclass* m_Left;
00187 Superclass* m_Right;
00188 };
00189
00204 template< class TSample >
00205 struct KdTreeWeightedCentroidNonterminalNode: public KdTreeNode< TSample >
00206 {
00207 typedef KdTreeNode< TSample > Superclass;
00208 typedef typename Superclass::MeasurementType MeasurementType;
00209 typedef typename Superclass::CentroidType CentroidType;
00210 typedef typename Superclass::InstanceIdentifier InstanceIdentifier;
00211 typedef typename TSample::MeasurementVectorSizeType MeasurementVectorSizeType;
00212
00213 KdTreeWeightedCentroidNonterminalNode(unsigned int partitionDimension,
00214 MeasurementType partitionValue,
00215 Superclass* left,
00216 Superclass* right,
00217 CentroidType ¢roid,
00218 unsigned int size);
00219 virtual ~KdTreeWeightedCentroidNonterminalNode() {}
00220
00221 virtual bool IsTerminal() const
00222 { return false; }
00223
00224 void GetParameters(unsigned int &partitionDimension,
00225 MeasurementType &partitionValue) const;
00226
00228 MeasurementVectorSizeType GetMeasurementVectorSize() const
00229 {
00230 return m_MeasurementVectorSize;
00231 }
00232
00233 Superclass* Left()
00234 { return m_Left; }
00235
00236 Superclass* Right()
00237 { return m_Right; }
00238
00239 const Superclass* Left() const
00240 { return m_Left; }
00241
00242 const Superclass* Right() const
00243 { return m_Right; }
00244
00245 unsigned int Size() const
00246 { return m_Size; }
00247
00248 void GetWeightedCentroid(CentroidType ¢roid)
00249 { centroid = m_WeightedCentroid; }
00250
00251 void GetCentroid(CentroidType ¢roid)
00252 { centroid = m_Centroid; }
00253
00254 InstanceIdentifier GetInstanceIdentifier(size_t) const
00255 { return this->m_InstanceIdentifier; }
00256
00257 void AddInstanceIdentifier(InstanceIdentifier valueId)
00258 { this->m_InstanceIdentifier = valueId; }
00259
00260 private:
00261 MeasurementVectorSizeType m_MeasurementVectorSize;
00262 unsigned int m_PartitionDimension;
00263 MeasurementType m_PartitionValue;
00264 CentroidType m_WeightedCentroid;
00265 CentroidType m_Centroid;
00266 InstanceIdentifier m_InstanceIdentifier;
00267 unsigned int m_Size;
00268 Superclass* m_Left;
00269 Superclass* m_Right;
00270 };
00271
00272
00284 template< class TSample >
00285 struct KdTreeTerminalNode: public KdTreeNode< TSample >
00286 {
00287 typedef KdTreeNode< TSample > Superclass;
00288 typedef typename Superclass::MeasurementType MeasurementType;
00289 typedef typename Superclass::CentroidType CentroidType;
00290 typedef typename Superclass::InstanceIdentifier InstanceIdentifier;
00291
00292 KdTreeTerminalNode() {}
00293
00294 virtual ~KdTreeTerminalNode() {}
00295
00296 bool IsTerminal() const
00297 { return true; }
00298
00299 void GetParameters(unsigned int &,
00300 MeasurementType &) const {}
00301
00302 Superclass* Left()
00303 { return 0; }
00304
00305 Superclass* Right()
00306 { return 0; }
00307
00308 const Superclass* Left() const
00309 { return 0; }
00310
00311 const Superclass* Right() const
00312 { return 0; }
00313
00314 unsigned int Size() const
00315 { return static_cast<unsigned int>( m_InstanceIdentifiers.size() ); }
00316
00317 void GetWeightedCentroid(CentroidType &)
00318 {}
00319
00320 void GetCentroid(CentroidType &)
00321 {}
00322
00323 InstanceIdentifier GetInstanceIdentifier(size_t index) const
00324 { return m_InstanceIdentifiers[index]; }
00325
00326 void AddInstanceIdentifier(InstanceIdentifier id)
00327 { m_InstanceIdentifiers.push_back(id);}
00328
00329 private:
00330 std::vector< InstanceIdentifier > m_InstanceIdentifiers;
00331 };
00332
00365 template < class TSample >
00366 class ITK_EXPORT KdTree : public Object
00367 {
00368 public:
00370 typedef KdTree Self;
00371 typedef Object Superclass;
00372 typedef SmartPointer<Self> Pointer;
00373 typedef SmartPointer<const Self> ConstPointer;
00374
00376 itkTypeMacro(KdTree, Object);
00377
00379 itkNewMacro(Self);
00380
00382 typedef TSample SampleType;
00383 typedef typename TSample::MeasurementVectorType MeasurementVectorType;
00384 typedef typename TSample::MeasurementType MeasurementType;
00385 typedef typename TSample::InstanceIdentifier InstanceIdentifier;
00386 typedef typename TSample::FrequencyType FrequencyType;
00387
00388 typedef unsigned int MeasurementVectorSizeType;
00389
00392 itkGetConstMacro( MeasurementVectorSize, MeasurementVectorSizeType );
00393
00395 typedef EuclideanDistance< MeasurementVectorType > DistanceMetricType;
00396
00398 typedef KdTreeNode< TSample > KdTreeNodeType;
00399
00403 typedef std::pair< InstanceIdentifier, double > NeighborType;
00404
00405 typedef std::vector< InstanceIdentifier > InstanceIdentifierVectorType;
00406
00416 class NearestNeighbors
00417 {
00418 public:
00420 NearestNeighbors() {}
00421
00423 ~NearestNeighbors() {}
00424
00427 void resize(unsigned int k)
00428 {
00429 m_Identifiers.clear();
00430 m_Identifiers.resize(k, NumericTraits< unsigned long >::max());
00431 m_Distances.clear();
00432 m_Distances.resize(k, NumericTraits< double >::max());
00433 m_FarthestNeighborIndex = 0;
00434 }
00436
00438 double GetLargestDistance()
00439 { return m_Distances[m_FarthestNeighborIndex]; }
00440
00443 void ReplaceFarthestNeighbor(InstanceIdentifier id, double distance)
00444 {
00445 m_Identifiers[m_FarthestNeighborIndex] = id;
00446 m_Distances[m_FarthestNeighborIndex] = distance;
00447 double farthestDistance = NumericTraits< double >::min();
00448 const unsigned int size = static_cast<unsigned int>( m_Distances.size() );
00449 for ( unsigned int i = 0; i < size; i++ )
00450 {
00451 if ( m_Distances[i] > farthestDistance )
00452 {
00453 farthestDistance = m_Distances[i];
00454 m_FarthestNeighborIndex = i;
00455 }
00456 }
00457 }
00459
00461 const InstanceIdentifierVectorType & GetNeighbors() const
00462 { return m_Identifiers; }
00463
00466 InstanceIdentifier GetNeighbor(unsigned int index) const
00467 { return m_Identifiers[index]; }
00468
00470 const std::vector< double >& GetDistances() const
00471 { return m_Distances; }
00472
00473 private:
00475 unsigned int m_FarthestNeighborIndex;
00476
00478 InstanceIdentifierVectorType m_Identifiers;
00479
00482 std::vector< double > m_Distances;
00483 };
00484
00487 void SetBucketSize(unsigned int size);
00488
00491 void SetSample(const TSample* sample);
00492
00494 const TSample* GetSample() const
00495 { return m_Sample; }
00496
00497 unsigned long Size() const
00498 { return m_Sample->Size(); }
00499
00504 KdTreeNodeType* GetEmptyTerminalNode()
00505 { return m_EmptyTerminalNode; }
00506
00509 void SetRoot(KdTreeNodeType* root)
00510 { m_Root = root; }
00511
00513 KdTreeNodeType* GetRoot()
00514 { return m_Root; }
00515
00518 const MeasurementVectorType & GetMeasurementVector(InstanceIdentifier id) const
00519 { return m_Sample->GetMeasurementVector(id); }
00520
00523 FrequencyType GetFrequency(InstanceIdentifier id) const
00524 { return m_Sample->GetFrequency( id ); }
00525
00527 DistanceMetricType* GetDistanceMetric()
00528 { return m_DistanceMetric.GetPointer(); }
00529
00531 void Search(const MeasurementVectorType &query,
00532 unsigned int numberOfNeighborsRequested,
00533 InstanceIdentifierVectorType& result) const;
00534
00536 void Search(const MeasurementVectorType &query,
00537 double radius,
00538 InstanceIdentifierVectorType& result) const;
00539
00542 int GetNumberOfVisits() const
00543 { return m_NumberOfVisits; }
00544
00550 bool BallWithinBounds(const MeasurementVectorType &query,
00551 MeasurementVectorType &lowerBound,
00552 MeasurementVectorType &upperBound,
00553 double radius) const;
00554
00558 bool BoundsOverlapBall(const MeasurementVectorType &query,
00559 MeasurementVectorType &lowerBound,
00560 MeasurementVectorType &upperBound,
00561 double radius) const;
00562
00564 void DeleteNode(KdTreeNodeType *node);
00565
00567 void PrintTree( std::ostream & os ) const;
00568
00570 void PrintTree(KdTreeNodeType *node, unsigned int level,
00571 unsigned int activeDimension,
00572 std::ostream & os = std::cout ) const;
00573
00576 void PlotTree( std::ostream & os ) const;
00577
00579 void PlotTree(KdTreeNodeType *node, std::ostream & os = std::cout ) const;
00580
00581
00582 typedef typename TSample::Iterator Iterator;
00583 typedef typename TSample::ConstIterator ConstIterator;
00584
00585 Iterator Begin()
00586 {
00587 typename TSample::ConstIterator iter = m_Sample->Begin();
00588 return iter;
00589 }
00590
00591 Iterator End()
00592 {
00593 Iterator iter = m_Sample->End();
00594 return iter;
00595 }
00596
00597 ConstIterator Begin() const
00598 {
00599 typename TSample::ConstIterator iter = m_Sample->Begin();
00600 return iter;
00601 }
00602
00603 ConstIterator End() const
00604 {
00605 ConstIterator iter = m_Sample->End();
00606 return iter;
00607 }
00608
00609 protected:
00611 KdTree();
00612
00614 virtual ~KdTree();
00615
00616 void PrintSelf(std::ostream& os, Indent indent) const;
00617
00619 int NearestNeighborSearchLoop(const KdTreeNodeType* node,
00620 const MeasurementVectorType &query,
00621 MeasurementVectorType &lowerBound,
00622 MeasurementVectorType &upperBound) const;
00623
00625 int SearchLoop(const KdTreeNodeType* node, const MeasurementVectorType &query,
00626 MeasurementVectorType &lowerBound,
00627 MeasurementVectorType &upperBound) const;
00628 private:
00629 KdTree(const Self&);
00630 void operator=(const Self&);
00632
00634 const TSample* m_Sample;
00635
00637 int m_BucketSize;
00638
00640 KdTreeNodeType* m_Root;
00641
00643 KdTreeNodeType* m_EmptyTerminalNode;
00644
00646 typename DistanceMetricType::Pointer m_DistanceMetric;
00647
00648 mutable bool m_IsNearestNeighborSearch;
00649
00650 mutable double m_SearchRadius;
00651
00652 mutable InstanceIdentifierVectorType m_Neighbors;
00653
00655 mutable NearestNeighbors m_NearestNeighbors;
00656
00658 mutable MeasurementVectorType m_LowerBound;
00659
00661 mutable MeasurementVectorType m_UpperBound;
00662
00664 mutable int m_NumberOfVisits;
00665
00667 mutable bool m_StopSearch;
00668
00670 mutable NeighborType m_TempNeighbor;
00671
00673 MeasurementVectorSizeType m_MeasurementVectorSize;
00674 };
00675
00676 } // end of namespace Statistics
00677 } // end of namespace itk
00678
00679 #ifndef ITK_MANUAL_INSTANTIATION
00680 #include "itkKdTree.txx"
00681 #endif
00682
00683 #endif
00684