00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkKdTreeBasedKmeansEstimator_h
00018 #define __itkKdTreeBasedKmeansEstimator_h
00019
00020 #include <vector>
00021 #include "itk_hash_map.h"
00022
00023 #include "itkObject.h"
00024 #include "itkMeasurementVectorTraits.h"
00025
00026 namespace itk {
00027 namespace Statistics {
00028
00066 template< class TKdTree >
00067 class ITK_EXPORT KdTreeBasedKmeansEstimator:
00068 public Object
00069 {
00070 public:
00072 typedef KdTreeBasedKmeansEstimator Self ;
00073 typedef Object Superclass;
00074 typedef SmartPointer<Self> Pointer;
00075 typedef SmartPointer<const Self> ConstPointer;
00076
00078 itkNewMacro(Self);
00079
00081 itkTypeMacro(KdTreeBasedKmeansEstimator, Obeject);
00082
00084 typedef typename TKdTree::KdTreeNodeType KdTreeNodeType ;
00085 typedef typename TKdTree::MeasurementType MeasurementType ;
00086 typedef typename TKdTree::MeasurementVectorType MeasurementVectorType ;
00087 typedef typename TKdTree::InstanceIdentifier InstanceIdentifier ;
00088 typedef typename TKdTree::SampleType SampleType ;
00089 typedef typename KdTreeNodeType::CentroidType CentroidType ;
00090
00091
00093 typedef unsigned int MeasurementVectorSizeType;
00094
00097 typedef Array< double > ParameterType ;
00098 typedef std::vector< ParameterType > InternalParametersType;
00099 typedef Array< double > ParametersType;
00100
00102 void SetParameters(ParametersType& params)
00103 { m_Parameters = params ; }
00104
00106 ParametersType& GetParameters()
00107 { return m_Parameters ; }
00108
00110 itkSetMacro( MaximumIteration, int );
00111 itkGetConstReferenceMacro( MaximumIteration, int );
00113
00116 itkSetMacro( CentroidPositionChangesThreshold, double );
00117 itkGetConstReferenceMacro( CentroidPositionChangesThreshold, double );
00119
00121 void SetKdTree(TKdTree* tree)
00122 {
00123 m_KdTree = tree ;
00124 m_MeasurementVectorSize = tree->GetMeasurementVectorSize();
00125 m_DistanceMetric->SetMeasurementVectorSize( m_MeasurementVectorSize );
00126 MeasurementVectorTraits::SetLength( m_TempVertex, m_MeasurementVectorSize );
00127 }
00129
00130 TKdTree* GetKdTree()
00131 { return m_KdTree.GetPointer() ; }
00132
00134 itkGetConstReferenceMacro( MeasurementVectorSize, MeasurementVectorSizeType );
00135
00136 itkGetConstReferenceMacro( CurrentIteration, int) ;
00137 itkGetConstReferenceMacro( CentroidPositionChanges, double) ;
00138
00143 void StartOptimization() ;
00144
00145 typedef itk::hash_map< InstanceIdentifier, unsigned int > ClusterLabelsType ;
00146
00147 void SetUseClusterLabels(bool flag)
00148 { m_UseClusterLabels = flag ; }
00149
00150 ClusterLabelsType* GetClusterLabels()
00151 { return &m_ClusterLabels ; }
00152
00153 protected:
00154 KdTreeBasedKmeansEstimator() ;
00155 virtual ~KdTreeBasedKmeansEstimator() {}
00156
00157 void PrintSelf(std::ostream& os, Indent indent) const;
00158
00159 void FillClusterLabels(KdTreeNodeType* node, int closestIndex) ;
00160
00162 class CandidateVector
00163 {
00164 public:
00165 CandidateVector() {}
00166
00167 struct Candidate
00168 {
00169 CentroidType Centroid ;
00170 CentroidType WeightedCentroid ;
00171 int Size ;
00172 } ;
00173
00174 virtual ~CandidateVector() {}
00175
00177 int Size() const
00178 { return static_cast<int>( m_Candidates.size() ); }
00179
00182 void SetCentroids(InternalParametersType& centroids)
00183 {
00184 this->m_MeasurementVectorSize = MeasurementVectorTraits::GetLength( centroids[0] );
00185 m_Candidates.resize(centroids.size()) ;
00186 for (unsigned int i = 0 ; i < centroids.size() ; i++)
00187 {
00188 Candidate candidate ;
00189 candidate.Centroid = centroids[i] ;
00190 MeasurementVectorTraits::SetLength( candidate.WeightedCentroid, m_MeasurementVectorSize );
00191 candidate.WeightedCentroid.Fill(0.0) ;
00192 candidate.Size = 0 ;
00193 m_Candidates[i] = candidate ;
00194 }
00195 }
00197
00199 void GetCentroids(InternalParametersType& centroids)
00200 {
00201 unsigned int i ;
00202 centroids.resize(this->Size()) ;
00203 for (i = 0 ; i < (unsigned int)this->Size() ; i++)
00204 {
00205 centroids[i] = m_Candidates[i].Centroid ;
00206 }
00207 }
00209
00212 void UpdateCentroids()
00213 {
00214 unsigned int i, j ;
00215 for (i = 0 ; i < (unsigned int)this->Size() ; i++)
00216 {
00217 if (m_Candidates[i].Size > 0)
00218 {
00219 for (j = 0 ; j < m_MeasurementVectorSize; j++)
00220 {
00221 m_Candidates[i].Centroid[j] =
00222 m_Candidates[i].WeightedCentroid[j] /
00223 double(m_Candidates[i].Size) ;
00224 }
00225 }
00226 }
00227 }
00229
00231 Candidate& operator[](int index)
00232 { return m_Candidates[index] ; }
00233
00234
00235 private:
00237 std::vector< Candidate > m_Candidates ;
00238
00240 MeasurementVectorSizeType m_MeasurementVectorSize;
00241 } ;
00242
00248 double GetSumOfSquaredPositionChanges(InternalParametersType &previous,
00249 InternalParametersType ¤t) ;
00250
00253 int GetClosestCandidate(ParameterType &measurements,
00254 std::vector< int > &validIndexes) ;
00255
00257 bool IsFarther(ParameterType &pointA,
00258 ParameterType &pointB,
00259 MeasurementVectorType &lowerBound,
00260 MeasurementVectorType &upperBound) ;
00261
00264 void Filter(KdTreeNodeType* node,
00265 std::vector< int > validIndexes,
00266 MeasurementVectorType &lowerBound,
00267 MeasurementVectorType &upperBound) ;
00268
00270 void CopyParameters(InternalParametersType &source, InternalParametersType &target) ;
00271
00273 void CopyParameters(ParametersType &source, InternalParametersType &target) ;
00274
00276 void CopyParameters(InternalParametersType &source, ParametersType &target) ;
00277
00279 void GetPoint(ParameterType &point,
00280 MeasurementVectorType measurements)
00281 {
00282 for (unsigned int i = 0 ; i < m_MeasurementVectorSize ; i++)
00283 {
00284 point[i] = measurements[i] ;
00285 }
00286 }
00288
00289 void PrintPoint(ParameterType &point)
00290 {
00291 std::cout << "[ " ;
00292 for (unsigned int i = 0 ; i < m_MeasurementVectorSize ; i++)
00293 {
00294 std::cout << point[i] << " " ;
00295 }
00296 std::cout << "]" ;
00297 }
00298
00299 private:
00301 int m_CurrentIteration ;
00302
00304 int m_MaximumIteration ;
00305
00307 double m_CentroidPositionChanges ;
00308
00311 double m_CentroidPositionChangesThreshold ;
00312
00314 typename TKdTree::Pointer m_KdTree ;
00315
00317 typename EuclideanDistance< ParameterType >::Pointer m_DistanceMetric ;
00318
00320 ParametersType m_Parameters ;
00321
00322 CandidateVector m_CandidateVector ;
00323
00324 ParameterType m_TempVertex ;
00325
00326 bool m_UseClusterLabels ;
00327 bool m_GenerateClusterLabels ;
00328 ClusterLabelsType m_ClusterLabels ;
00329 MeasurementVectorSizeType m_MeasurementVectorSize;
00330 } ;
00331
00332 }
00333 }
00334
00335 #ifndef ITK_MANUAL_INSTANTIATION
00336 #include "itkKdTreeBasedKmeansEstimator.txx"
00337 #endif
00338
00339
00340 #endif
00341