Numerics/Statistics/itkKdTree.h
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkKdTree_h
00018 #define __itkKdTree_h
00019
00020 #include <queue>
00021 #include <vector>
00022
00023 #include "itkMacro.h"
00024 #include "itkPoint.h"
00025 #include "itkSize.h"
00026 #include "itkObject.h"
00027 #include "itkNumericTraits.h"
00028 #include "itkArray.h"
00029
00030 #include "itkSample.h"
00031 #include "itkSubsample.h"
00032
00033 #include "itkEuclideanDistance.h"
00034
00035 namespace itk {
00036 namespace Statistics {
00037
00065 template< class TSample >
00066 struct KdTreeNode
00067 {
00069 typedef KdTreeNode< TSample> Self;
00070
00072 typedef typename TSample::MeasurementType MeasurementType;
00073
00075 typedef Array< double > CentroidType;
00076
00079 typedef typename TSample::InstanceIdentifier InstanceIdentifier;
00080
00083 virtual bool IsTerminal() const = 0;
00084
00090 virtual void GetParameters(unsigned int &partitionDimension,
00091 MeasurementType &partitionValue) const = 0;
00092
00094 virtual Self* Left() = 0;
00095 virtual const Self* Left() const = 0;
00097
00099 virtual Self* Right() = 0;
00100 virtual const Self* Right() const = 0;
00102
00105 virtual unsigned int Size() const = 0;
00106
00108 virtual void GetWeightedCentroid(CentroidType ¢roid) = 0;
00109
00111 virtual void GetCentroid(CentroidType ¢roid) = 0;
00112
00114 virtual InstanceIdentifier GetInstanceIdentifier(size_t index) const = 0;
00115
00117 virtual void AddInstanceIdentifier(InstanceIdentifier id) = 0;
00118
00120 virtual ~KdTreeNode() {};
00121 };
00122
00134 template< class TSample >
00135 struct KdTreeNonterminalNode: public KdTreeNode< TSample >
00136 {
00137 typedef KdTreeNode< TSample > Superclass;
00138 typedef typename Superclass::MeasurementType MeasurementType;
00139 typedef typename Superclass::CentroidType CentroidType;
00140 typedef typename Superclass::InstanceIdentifier InstanceIdentifier;
00141
00142 KdTreeNonterminalNode(unsigned int partitionDimension,
00143 MeasurementType partitionValue,
00144 Superclass* left,
00145 Superclass* right);
00146
00147 virtual ~KdTreeNonterminalNode() {}
00148
00149 virtual bool IsTerminal() const
00150 { return false; }
00151
00152 void GetParameters(unsigned int &partitionDimension,
00153 MeasurementType &partitionValue) const;
00154
00155 Superclass* Left()
00156 { return m_Left; }
00157
00158 Superclass* Right()
00159 { return m_Right; }
00160
00161 const Superclass* Left() const
00162 { return m_Left; }
00163
00164 const Superclass* Right() const
00165 { return m_Right; }
00166
00167 unsigned int Size() const
00168 { return 0; }
00169
00170 void GetWeightedCentroid( CentroidType & )
00171 {}
00172
00173 void GetCentroid( CentroidType & )
00174 {}
00175
00176
00177
00178
00179 InstanceIdentifier GetInstanceIdentifier(size_t) const
00180 { return this->m_InstanceIdentifier; }
00181
00182 void AddInstanceIdentifier(InstanceIdentifier valueId)
00183 { this->m_InstanceIdentifier = valueId; }
00184
00185 private:
00186 unsigned int m_PartitionDimension;
00187 MeasurementType m_PartitionValue;
00188 InstanceIdentifier m_InstanceIdentifier;
00189 Superclass* m_Left;
00190 Superclass* m_Right;
00191 };
00192
00207 template< class TSample >
00208 struct KdTreeWeightedCentroidNonterminalNode: public KdTreeNode< TSample >
00209 {
00210 typedef KdTreeNode< TSample > Superclass;
00211 typedef typename Superclass::MeasurementType MeasurementType;
00212 typedef typename Superclass::CentroidType CentroidType;
00213 typedef typename Superclass::InstanceIdentifier InstanceIdentifier;
00214 typedef typename TSample::MeasurementVectorSizeType MeasurementVectorSizeType;
00215
00216 KdTreeWeightedCentroidNonterminalNode(unsigned int partitionDimension,
00217 MeasurementType partitionValue,
00218 Superclass* left,
00219 Superclass* right,
00220 CentroidType ¢roid,
00221 unsigned int size);
00222 virtual ~KdTreeWeightedCentroidNonterminalNode()
00223 {
00224 }
00225
00226
00227 virtual bool IsTerminal() const
00228 { return false; }
00229
00230 void GetParameters(unsigned int &partitionDimension,
00231 MeasurementType &partitionValue) const;
00232
00234 MeasurementVectorSizeType GetMeasurementVectorSize() const
00235 {
00236 return m_MeasurementVectorSize;
00237 }
00238
00239 Superclass* Left()
00240 { return m_Left; }
00241
00242 Superclass* Right()
00243 { return m_Right; }
00244
00245 const Superclass* Left() const
00246 { return m_Left; }
00247
00248 const Superclass* Right() const
00249 { return m_Right; }
00250
00251 unsigned int Size() const
00252 { return m_Size; }
00253
00254 void GetWeightedCentroid(CentroidType ¢roid)
00255 { centroid = m_WeightedCentroid; }
00256
00257 void GetCentroid(CentroidType ¢roid)
00258 { centroid = m_Centroid; }
00259
00260 InstanceIdentifier GetInstanceIdentifier(size_t) const
00261 { return this->m_InstanceIdentifier; }
00262
00263 void AddInstanceIdentifier(InstanceIdentifier valueId)
00264 { this->m_InstanceIdentifier = valueId; }
00265
00266 private:
00267 MeasurementVectorSizeType m_MeasurementVectorSize;
00268 unsigned int m_PartitionDimension;
00269 MeasurementType m_PartitionValue;
00270 CentroidType m_WeightedCentroid;
00271 CentroidType m_Centroid;
00272 InstanceIdentifier m_InstanceIdentifier;
00273 unsigned int m_Size;
00274 Superclass* m_Left;
00275 Superclass* m_Right;
00276 };
00277
00278
00290 template< class TSample >
00291 struct KdTreeTerminalNode: public KdTreeNode< TSample >
00292 {
00293 typedef KdTreeNode< TSample > Superclass;
00294 typedef typename Superclass::MeasurementType MeasurementType;
00295 typedef typename Superclass::CentroidType CentroidType;
00296 typedef typename Superclass::InstanceIdentifier InstanceIdentifier;
00297
00298 KdTreeTerminalNode() {}
00299
00300 virtual ~KdTreeTerminalNode()
00301 {
00302 this->m_InstanceIdentifiers.clear();
00303 }
00304
00305 bool IsTerminal() const
00306 { return true; }
00307
00308 void GetParameters(unsigned int &,
00309 MeasurementType &) const {}
00310
00311 Superclass* Left()
00312 { return 0; }
00313
00314 Superclass* Right()
00315 { return 0; }
00316
00317 const Superclass* Left() const
00318 { return 0; }
00319
00320 const Superclass* Right() const
00321 { return 0; }
00322
00323 unsigned int Size() const
00324 { return static_cast<unsigned int>( m_InstanceIdentifiers.size() ); }
00325
00326 void GetWeightedCentroid(CentroidType &)
00327 {}
00328
00329 void GetCentroid(CentroidType &)
00330 {}
00331
00332 InstanceIdentifier GetInstanceIdentifier(size_t index) const
00333 { return m_InstanceIdentifiers[index]; }
00334
00335 void AddInstanceIdentifier(InstanceIdentifier id)
00336 { m_InstanceIdentifiers.push_back(id);}
00337
00338 private:
00339 std::vector< InstanceIdentifier > m_InstanceIdentifiers;
00340 };
00341
00374 template < class TSample >
00375 class ITK_EXPORT KdTree : public Object
00376 {
00377 public:
00379 typedef KdTree Self;
00380 typedef Object Superclass;
00381 typedef SmartPointer<Self> Pointer;
00382 typedef SmartPointer<const Self> ConstPointer;
00383
00385 itkTypeMacro(KdTree, Object);
00386
00388 itkNewMacro(Self);
00389
00391 typedef TSample SampleType;
00392 typedef typename TSample::MeasurementVectorType MeasurementVectorType;
00393 typedef typename TSample::MeasurementType MeasurementType;
00394 typedef typename TSample::InstanceIdentifier InstanceIdentifier;
00395 typedef typename TSample::FrequencyType FrequencyType;
00396
00397 typedef unsigned int MeasurementVectorSizeType;
00398
00401 itkGetConstMacro( MeasurementVectorSize, MeasurementVectorSizeType );
00402
00404 typedef EuclideanDistance< MeasurementVectorType > DistanceMetricType;
00405
00407 typedef KdTreeNode< TSample > KdTreeNodeType;
00408
00412 typedef std::pair< InstanceIdentifier, double > NeighborType;
00413
00414 typedef std::vector< InstanceIdentifier > InstanceIdentifierVectorType;
00415
00425 class NearestNeighbors
00426 {
00427 public:
00429 NearestNeighbors() {}
00430
00432 ~NearestNeighbors() {}
00433
00436 void resize(unsigned int k)
00437 {
00438 m_Identifiers.clear();
00439 m_Identifiers.resize(k, NumericTraits< unsigned long >::max());
00440 m_Distances.clear();
00441 m_Distances.resize(k, NumericTraits< double >::max());
00442 m_FarthestNeighborIndex = 0;
00443 }
00445
00447 double GetLargestDistance()
00448 { return m_Distances[m_FarthestNeighborIndex]; }
00449
00452 void ReplaceFarthestNeighbor(InstanceIdentifier id, double distance)
00453 {
00454 m_Identifiers[m_FarthestNeighborIndex] = id;
00455 m_Distances[m_FarthestNeighborIndex] = distance;
00456 double farthestDistance = NumericTraits< double >::min();
00457 const unsigned int size = static_cast<unsigned int>( m_Distances.size() );
00458 for ( unsigned int i = 0; i < size; i++ )
00459 {
00460 if ( m_Distances[i] > farthestDistance )
00461 {
00462 farthestDistance = m_Distances[i];
00463 m_FarthestNeighborIndex = i;
00464 }
00465 }
00466 }
00468
00470 const InstanceIdentifierVectorType & GetNeighbors() const
00471 { return m_Identifiers; }
00472
00475 InstanceIdentifier GetNeighbor(unsigned int index) const
00476 { return m_Identifiers[index]; }
00477
00479 const std::vector< double >& GetDistances() const
00480 { return m_Distances; }
00481
00482 private:
00484 unsigned int m_FarthestNeighborIndex;
00485
00487 InstanceIdentifierVectorType m_Identifiers;
00488
00491 std::vector< double > m_Distances;
00492 };
00493
00496 void SetBucketSize(unsigned int size);
00497
00500 void SetSample(const TSample* sample);
00501
00503 const TSample* GetSample() const
00504 { return m_Sample; }
00505
00506 unsigned long Size() const
00507 { return m_Sample->Size(); }
00508
00513 KdTreeNodeType* GetEmptyTerminalNode()
00514 { return m_EmptyTerminalNode; }
00515
00518 void SetRoot(KdTreeNodeType* root)
00519 {
00520 if( this->m_Root )
00521 {
00522 this->DeleteNode( this->m_Root );
00523 }
00524 this->m_Root = root;
00525 }
00527
00529 KdTreeNodeType* GetRoot()
00530 { return m_Root; }
00531
00534 const MeasurementVectorType & GetMeasurementVector(InstanceIdentifier id) const
00535 { return m_Sample->GetMeasurementVector(id); }
00536
00539 FrequencyType GetFrequency(InstanceIdentifier id) const
00540 { return m_Sample->GetFrequency( id ); }
00541
00543 DistanceMetricType* GetDistanceMetric()
00544 { return m_DistanceMetric.GetPointer(); }
00545
00547 void Search(const MeasurementVectorType &query,
00548 unsigned int numberOfNeighborsRequested,
00549 InstanceIdentifierVectorType& result) const;
00550
00552 void Search(const MeasurementVectorType &query,
00553 double radius,
00554 InstanceIdentifierVectorType& result) const;
00555
00558 int GetNumberOfVisits() const
00559 { return m_NumberOfVisits; }
00560
00566 bool BallWithinBounds(const MeasurementVectorType &query,
00567 MeasurementVectorType &lowerBound,
00568 MeasurementVectorType &upperBound,
00569 double radius) const;
00570
00574 bool BoundsOverlapBall(const MeasurementVectorType &query,
00575 MeasurementVectorType &lowerBound,
00576 MeasurementVectorType &upperBound,
00577 double radius) const;
00578
00580 void DeleteNode(KdTreeNodeType *node);
00581
00583 void PrintTree( std::ostream & os ) const;
00584
00586 void PrintTree(KdTreeNodeType *node, unsigned int level,
00587 unsigned int activeDimension,
00588 std::ostream & os = std::cout ) const;
00589
00592 void PlotTree( std::ostream & os ) const;
00593
00595 void PlotTree(KdTreeNodeType *node, std::ostream & os = std::cout ) const;
00596
00597
00598 typedef typename TSample::Iterator Iterator;
00599 typedef typename TSample::ConstIterator ConstIterator;
00600
00601 Iterator Begin()
00602 {
00603 typename TSample::ConstIterator iter = m_Sample->Begin();
00604 return iter;
00605 }
00606
00607 Iterator End()
00608 {
00609 Iterator iter = m_Sample->End();
00610 return iter;
00611 }
00612
00613 ConstIterator Begin() const
00614 {
00615 typename TSample::ConstIterator iter = m_Sample->Begin();
00616 return iter;
00617 }
00618
00619 ConstIterator End() const
00620 {
00621 ConstIterator iter = m_Sample->End();
00622 return iter;
00623 }
00624
00625 protected:
00627 KdTree();
00628
00630 virtual ~KdTree();
00631
00632 void PrintSelf(std::ostream& os, Indent indent) const;
00633
00635 int NearestNeighborSearchLoop(const KdTreeNodeType* node,
00636 const MeasurementVectorType &query,
00637 MeasurementVectorType &lowerBound,
00638 MeasurementVectorType &upperBound) const;
00639
00641 int SearchLoop(const KdTreeNodeType* node, const MeasurementVectorType &query,
00642 MeasurementVectorType &lowerBound,
00643 MeasurementVectorType &upperBound) const;
00644 private:
00645 KdTree(const Self&);
00646 void operator=(const Self&);
00648
00650 const TSample* m_Sample;
00651
00653 int m_BucketSize;
00654
00656 KdTreeNodeType* m_Root;
00657
00659 KdTreeNodeType* m_EmptyTerminalNode;
00660
00662 typename DistanceMetricType::Pointer m_DistanceMetric;
00663
00664 mutable bool m_IsNearestNeighborSearch;
00665
00666 mutable double m_SearchRadius;
00667
00668 mutable InstanceIdentifierVectorType m_Neighbors;
00669
00671 mutable NearestNeighbors m_NearestNeighbors;
00672
00674 mutable MeasurementVectorType m_LowerBound;
00675
00677 mutable MeasurementVectorType m_UpperBound;
00678
00680 mutable int m_NumberOfVisits;
00681
00683 mutable bool m_StopSearch;
00684
00686 mutable NeighborType m_TempNeighbor;
00687
00689 MeasurementVectorSizeType m_MeasurementVectorSize;
00690 };
00691
00692 }
00693 }
00694
00695 #ifndef ITK_MANUAL_INSTANTIATION
00696 #include "itkKdTree.txx"
00697 #endif
00698
00699 #endif
00700