00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkKdTree_h
00018 #define __itkKdTree_h
00019
00020 #include <queue>
00021 #include <vector>
00022
00023 #include "itkMacro.h"
00024 #include "itkPoint.h"
00025 #include "itkSize.h"
00026 #include "itkObject.h"
00027 #include "itkNumericTraits.h"
00028 #include "itkArray.h"
00029
00030 #include "itkSample.h"
00031 #include "itkSubsample.h"
00032
00033 #include "itkEuclideanDistance.h"
00034
00035 namespace itk{
00036 namespace Statistics{
00037
00062 template< class TSample >
00063 struct KdTreeNode
00064 {
00066 typedef KdTreeNode< TSample> Self ;
00067
00069 typedef typename TSample::MeasurementType MeasurementType ;
00070
00072 typedef Array< double > CentroidType;
00073
00076 typedef typename TSample::InstanceIdentifier InstanceIdentifier ;
00077
00080 virtual bool IsTerminal() const = 0 ;
00081
00087 virtual void GetParameters(unsigned int &partitionDimension,
00088 MeasurementType &partitionValue) const = 0 ;
00089
00091 virtual Self* Left() = 0 ;
00092 virtual const Self* Left() const = 0 ;
00094
00096 virtual Self* Right() = 0 ;
00097 virtual const Self* Right() const = 0 ;
00099
00102 virtual unsigned int Size() const = 0 ;
00103
00105 virtual void GetWeightedCentroid(CentroidType ¢roid) = 0 ;
00106
00108 virtual void GetCentroid(CentroidType ¢roid) = 0 ;
00109
00111 virtual InstanceIdentifier GetInstanceIdentifier(size_t index) const = 0 ;
00112
00114 virtual void AddInstanceIdentifier(InstanceIdentifier id) = 0 ;
00115
00117 virtual ~KdTreeNode() {};
00118 } ;
00119
00131 template< class TSample >
00132 struct KdTreeNonterminalNode: public KdTreeNode< TSample >
00133 {
00134 typedef KdTreeNode< TSample > Superclass ;
00135 typedef typename Superclass::MeasurementType MeasurementType ;
00136 typedef typename Superclass::CentroidType CentroidType ;
00137 typedef typename Superclass::InstanceIdentifier InstanceIdentifier ;
00138
00139 KdTreeNonterminalNode(unsigned int partitionDimension,
00140 MeasurementType partitionValue,
00141 Superclass* left,
00142 Superclass* right) ;
00143
00144 virtual ~KdTreeNonterminalNode() {}
00145
00146 virtual bool IsTerminal() const
00147 { return false ; }
00148
00149 void GetParameters(unsigned int &partitionDimension,
00150 MeasurementType &partitionValue) const;
00151
00152 Superclass* Left()
00153 { return m_Left ; }
00154
00155 Superclass* Right()
00156 { return m_Right ; }
00157
00158 const Superclass* Left() const
00159 { return m_Left ; }
00160
00161 const Superclass* Right() const
00162 { return m_Right ; }
00163
00164 unsigned int Size() const
00165 { return 0 ; }
00166
00167 void GetWeightedCentroid(CentroidType &)
00168 { }
00169
00170 void GetCentroid(CentroidType &)
00171 { }
00172
00173 InstanceIdentifier GetInstanceIdentifier(size_t) const
00174 { return 0 ; }
00175
00176 void AddInstanceIdentifier(InstanceIdentifier) {}
00177
00178 private:
00179 unsigned int m_PartitionDimension ;
00180 MeasurementType m_PartitionValue ;
00181 Superclass* m_Left ;
00182 Superclass* m_Right ;
00183 } ;
00184
00199 template< class TSample >
00200 struct KdTreeWeightedCentroidNonterminalNode: public KdTreeNode< TSample >
00201 {
00202 typedef KdTreeNode< TSample > Superclass ;
00203 typedef typename Superclass::MeasurementType MeasurementType ;
00204 typedef typename Superclass::CentroidType CentroidType ;
00205 typedef typename Superclass::InstanceIdentifier InstanceIdentifier ;
00206 typedef typename TSample::MeasurementVectorSizeType MeasurementVectorSizeType;
00207
00208 KdTreeWeightedCentroidNonterminalNode(unsigned int partitionDimension,
00209 MeasurementType partitionValue,
00210 Superclass* left,
00211 Superclass* right,
00212 CentroidType ¢roid,
00213 unsigned int size) ;
00214 virtual ~KdTreeWeightedCentroidNonterminalNode() {}
00215
00216 virtual bool IsTerminal() const
00217 { return false ; }
00218
00219 void GetParameters(unsigned int &partitionDimension,
00220 MeasurementType &partitionValue) const ;
00221
00223 MeasurementVectorSizeType GetMeasurementVectorSize() const
00224 {
00225 return m_MeasurementVectorSize;
00226 }
00227
00228 Superclass* Left()
00229 { return m_Left ; }
00230
00231 Superclass* Right()
00232 { return m_Right ; }
00233
00234
00235 const Superclass* Left() const
00236 { return m_Left ; }
00237
00238 const Superclass* Right() const
00239 { return m_Right ; }
00240
00241 unsigned int Size() const
00242 { return m_Size ; }
00243
00244 void GetWeightedCentroid(CentroidType ¢roid)
00245 { centroid = m_WeightedCentroid ; }
00246
00247 void GetCentroid(CentroidType ¢roid)
00248 { centroid = m_Centroid ; }
00249
00250 InstanceIdentifier GetInstanceIdentifier(size_t) const
00251 { return 0 ; }
00252
00253 void AddInstanceIdentifier(InstanceIdentifier) {}
00254
00255 private:
00256 MeasurementVectorSizeType m_MeasurementVectorSize;
00257 unsigned int m_PartitionDimension ;
00258 MeasurementType m_PartitionValue ;
00259 CentroidType m_WeightedCentroid ;
00260 CentroidType m_Centroid ;
00261 unsigned int m_Size ;
00262 Superclass* m_Left ;
00263 Superclass* m_Right ;
00264 } ;
00265
00266
00278 template< class TSample >
00279 struct KdTreeTerminalNode: public KdTreeNode< TSample >
00280 {
00281 typedef KdTreeNode< TSample > Superclass ;
00282 typedef typename Superclass::MeasurementType MeasurementType ;
00283 typedef typename Superclass::CentroidType CentroidType ;
00284 typedef typename Superclass::InstanceIdentifier InstanceIdentifier ;
00285
00286 KdTreeTerminalNode() {}
00287
00288 virtual ~KdTreeTerminalNode() {}
00289
00290 bool IsTerminal() const
00291 { return true ; }
00292
00293 void GetParameters(unsigned int &,
00294 MeasurementType &) const {}
00295
00296 Superclass* Left()
00297 { return 0 ; }
00298
00299 Superclass* Right()
00300 { return 0 ; }
00301
00302
00303 const Superclass* Left() const
00304 { return 0 ; }
00305
00306 const Superclass* Right() const
00307 { return 0 ; }
00308
00309 unsigned int Size() const
00310 { return static_cast<unsigned int>( m_InstanceIdentifiers.size() ); }
00311
00312 void GetWeightedCentroid(CentroidType &)
00313 { }
00314
00315 void GetCentroid(CentroidType &)
00316 { }
00317
00318 InstanceIdentifier GetInstanceIdentifier(size_t index) const
00319 { return m_InstanceIdentifiers[index] ; }
00320
00321 void AddInstanceIdentifier(InstanceIdentifier id)
00322 { m_InstanceIdentifiers.push_back(id) ;}
00323
00324 private:
00325 std::vector< InstanceIdentifier > m_InstanceIdentifiers ;
00326 } ;
00327
00360 template < class TSample >
00361 class ITK_EXPORT KdTree : public Object
00362 {
00363 public:
00365 typedef KdTree Self ;
00366 typedef Object Superclass ;
00367 typedef SmartPointer<Self> Pointer;
00368 typedef SmartPointer<const Self> ConstPointer;
00369
00371 itkTypeMacro(KdTree, Object);
00372
00374 itkNewMacro(Self) ;
00375
00377 typedef TSample SampleType ;
00378 typedef typename TSample::MeasurementVectorType MeasurementVectorType ;
00379 typedef typename TSample::MeasurementType MeasurementType ;
00380 typedef typename TSample::InstanceIdentifier InstanceIdentifier ;
00381 typedef typename TSample::FrequencyType FrequencyType ;
00382
00383 typedef unsigned int MeasurementVectorSizeType;
00384
00387 itkGetConstMacro( MeasurementVectorSize, MeasurementVectorSizeType );
00388
00390 typedef EuclideanDistance< MeasurementVectorType > DistanceMetricType ;
00391
00393 typedef KdTreeNode< TSample > KdTreeNodeType ;
00394
00398 typedef std::pair< InstanceIdentifier, double > NeighborType ;
00399
00400 typedef std::vector< InstanceIdentifier > InstanceIdentifierVectorType ;
00401
00411 class NearestNeighbors
00412 {
00413 public:
00415 NearestNeighbors() {}
00416
00418 ~NearestNeighbors() {}
00419
00422 void resize(unsigned int k)
00423 {
00424 m_Identifiers.clear() ;
00425 m_Identifiers.resize(k, NumericTraits< unsigned long >::max()) ;
00426 m_Distances.clear() ;
00427 m_Distances.resize(k, NumericTraits< double >::max()) ;
00428 m_FarthestNeighborIndex = 0 ;
00429 }
00431
00433 double GetLargestDistance()
00434 { return m_Distances[m_FarthestNeighborIndex] ; }
00435
00438 void ReplaceFarthestNeighbor(InstanceIdentifier id, double distance)
00439 {
00440 m_Identifiers[m_FarthestNeighborIndex] = id ;
00441 m_Distances[m_FarthestNeighborIndex] = distance ;
00442 double farthestDistance = NumericTraits< double >::min() ;
00443 const unsigned int size = static_cast<unsigned int>( m_Distances.size() );
00444 for ( unsigned int i = 0 ; i < size; i++ )
00445 {
00446 if ( m_Distances[i] > farthestDistance )
00447 {
00448 farthestDistance = m_Distances[i] ;
00449 m_FarthestNeighborIndex = i ;
00450 }
00451 }
00452 }
00454
00456 InstanceIdentifierVectorType GetNeighbors()
00457 { return m_Identifiers ; }
00458
00461 InstanceIdentifier GetNeighbor(unsigned int index)
00462 { return m_Identifiers[index] ; }
00463
00465 std::vector< double >& GetDistances()
00466 { return m_Distances ; }
00467
00468 private:
00470 unsigned int m_FarthestNeighborIndex ;
00471
00473 InstanceIdentifierVectorType m_Identifiers ;
00474
00477 std::vector< double > m_Distances ;
00478 } ;
00479
00482 void SetBucketSize(unsigned int size) ;
00483
00486 void SetSample(const TSample* sample) ;
00487
00489 const TSample* GetSample() const
00490 { return m_Sample ; }
00491
00492 unsigned long Size() const
00493 { return m_Sample->Size() ; }
00494
00499 KdTreeNodeType* GetEmptyTerminalNode()
00500 { return m_EmptyTerminalNode ; }
00501
00504 void SetRoot(KdTreeNodeType* root)
00505 { m_Root = root ; }
00506
00508 KdTreeNodeType* GetRoot()
00509 { return m_Root ; }
00510
00513 const MeasurementVectorType & GetMeasurementVector(InstanceIdentifier id) const
00514 { return m_Sample->GetMeasurementVector(id) ; }
00515
00518 FrequencyType GetFrequency(InstanceIdentifier id) const
00519 { return m_Sample->GetFrequency( id ) ; }
00520
00522 DistanceMetricType* GetDistanceMetric()
00523 { return m_DistanceMetric.GetPointer() ; }
00524
00526 void Search(MeasurementVectorType &query,
00527 unsigned int k,
00528 InstanceIdentifierVectorType& result) const;
00529
00531 void Search(MeasurementVectorType &query,
00532 double radius,
00533 InstanceIdentifierVectorType& result) const;
00534
00537 int GetNumberOfVisits() const
00538 { return m_NumberOfVisits ; }
00539
00545 bool BallWithinBounds(MeasurementVectorType &query,
00546 MeasurementVectorType &lowerBound,
00547 MeasurementVectorType &upperBound,
00548 double radius) const ;
00549
00553 bool BoundsOverlapBall(MeasurementVectorType &query,
00554 MeasurementVectorType &lowerBound,
00555 MeasurementVectorType &upperBound,
00556 double radius) const ;
00557
00559 void DeleteNode(KdTreeNodeType *node) ;
00560
00562 void PrintTree(KdTreeNodeType *node, int level,
00563 unsigned int activeDimension) ;
00564
00565 typedef typename TSample::Iterator Iterator ;
00566 typedef typename TSample::ConstIterator ConstIterator ;
00567
00568 Iterator Begin()
00569 {
00570 typename TSample::ConstIterator iter = m_Sample->Begin() ;
00571 return iter;
00572 }
00573
00574 Iterator End()
00575 {
00576 Iterator iter = m_Sample->End() ;
00577 return iter;
00578 }
00579
00580 ConstIterator Begin() const
00581 {
00582 typename TSample::ConstIterator iter = m_Sample->Begin() ;
00583 return iter;
00584 }
00585
00586 ConstIterator End() const
00587 {
00588 ConstIterator iter = m_Sample->End() ;
00589 return iter;
00590 }
00591
00592
00593 protected:
00595 KdTree() ;
00596
00598 virtual ~KdTree() ;
00599
00600 void PrintSelf(std::ostream& os, Indent indent) const ;
00601
00603 int NearestNeighborSearchLoop(const KdTreeNodeType* node,
00604 MeasurementVectorType &query,
00605 MeasurementVectorType &lowerBound,
00606 MeasurementVectorType &upperBound) const;
00607
00609 int SearchLoop(const KdTreeNodeType* node, MeasurementVectorType &query,
00610 MeasurementVectorType &lowerBound,
00611 MeasurementVectorType &upperBound) const ;
00612 private:
00613 KdTree(const Self&) ;
00614 void operator=(const Self&) ;
00616
00618 const TSample* m_Sample ;
00619
00621 int m_BucketSize ;
00622
00624 KdTreeNodeType* m_Root ;
00625
00627 KdTreeNodeType* m_EmptyTerminalNode ;
00628
00630 typename DistanceMetricType::Pointer m_DistanceMetric ;
00631
00632 mutable bool m_IsNearestNeighborSearch ;
00633
00634 mutable double m_SearchRadius ;
00635
00636 mutable InstanceIdentifierVectorType m_Neighbors ;
00637
00639 mutable NearestNeighbors m_NearestNeighbors ;
00640
00642 mutable MeasurementVectorType m_LowerBound ;
00643
00645 mutable MeasurementVectorType m_UpperBound ;
00646
00648 mutable int m_NumberOfVisits ;
00649
00651 mutable bool m_StopSearch ;
00652
00654 mutable NeighborType m_TempNeighbor ;
00655
00657 MeasurementVectorSizeType m_MeasurementVectorSize;
00658 } ;
00659
00660 }
00661 }
00662
00663 #ifndef ITK_MANUAL_INSTANTIATION
00664 #include "itkKdTree.txx"
00665 #endif
00666
00667 #endif
00668
00669
00670
00671