Numerics/Statistics/itkKdTreeBasedKmeansEstimator.h
Go to the documentation of this file.
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkKdTreeBasedKmeansEstimator_h
00018 #define __itkKdTreeBasedKmeansEstimator_h
00019
00020 #include <vector>
00021 #include "itk_hash_map.h"
00022
00023 #include "itkObject.h"
00024 #include "itkMeasurementVectorTraits.h"
00025
00026 namespace itk {
00027 namespace Statistics {
00028
00066 template< class TKdTree >
00067 class ITK_EXPORT KdTreeBasedKmeansEstimator:
00068 public Object
00069 {
00070 public:
00072 typedef KdTreeBasedKmeansEstimator Self;
00073 typedef Object Superclass;
00074 typedef SmartPointer<Self> Pointer;
00075 typedef SmartPointer<const Self> ConstPointer;
00076
00078 itkNewMacro(Self);
00079
00081 itkTypeMacro(KdTreeBasedKmeansEstimator, Obeject);
00082
00084 typedef typename TKdTree::KdTreeNodeType KdTreeNodeType;
00085 typedef typename TKdTree::MeasurementType MeasurementType;
00086 typedef typename TKdTree::MeasurementVectorType MeasurementVectorType;
00087 typedef typename TKdTree::InstanceIdentifier InstanceIdentifier;
00088 typedef typename TKdTree::SampleType SampleType;
00089 typedef typename KdTreeNodeType::CentroidType CentroidType;
00090
00091
00093 typedef unsigned int MeasurementVectorSizeType;
00094
00097 typedef Array< double > ParameterType;
00098 typedef std::vector< ParameterType > InternalParametersType;
00099 typedef Array< double > ParametersType;
00100
00102 void SetParameters(ParametersType& params)
00103 { m_Parameters = params; }
00104
00106 ParametersType& GetParameters()
00107 { return m_Parameters; }
00108
00110 itkSetMacro( MaximumIteration, int );
00111 itkGetConstReferenceMacro( MaximumIteration, int );
00113
00116 itkSetMacro( CentroidPositionChangesThreshold, double );
00117 itkGetConstReferenceMacro( CentroidPositionChangesThreshold, double );
00119
00121 void SetKdTree(TKdTree* tree)
00122 {
00123 m_KdTree = tree;
00124 m_MeasurementVectorSize = tree->GetMeasurementVectorSize();
00125 m_DistanceMetric->SetMeasurementVectorSize( m_MeasurementVectorSize );
00126 MeasurementVectorTraits::SetLength( m_TempVertex, m_MeasurementVectorSize );
00127 }
00129
00130 TKdTree* GetKdTree()
00131 { return m_KdTree.GetPointer(); }
00132
00134 itkGetConstReferenceMacro( MeasurementVectorSize, MeasurementVectorSizeType );
00135
00136 itkGetConstReferenceMacro( CurrentIteration, int);
00137 itkGetConstReferenceMacro( CentroidPositionChanges, double);
00138
00143 void StartOptimization();
00144
00145 typedef itk::hash_map< InstanceIdentifier, unsigned int > ClusterLabelsType;
00146
00147 void SetUseClusterLabels(bool flag)
00148 { m_UseClusterLabels = flag; }
00149
00150 ClusterLabelsType* GetClusterLabels()
00151 { return &m_ClusterLabels; }
00152
00153 protected:
00154 KdTreeBasedKmeansEstimator();
00155 virtual ~KdTreeBasedKmeansEstimator() {}
00156
00157 void PrintSelf(std::ostream& os, Indent indent) const;
00158
00159 void FillClusterLabels(KdTreeNodeType* node, int closestIndex);
00160
00163 class CandidateVector
00164 {
00165 public:
00166 CandidateVector() {}
00167
00168 struct Candidate
00169 {
00170 CentroidType Centroid;
00171 CentroidType WeightedCentroid;
00172 int Size;
00173 };
00174
00175 virtual ~CandidateVector() {}
00176
00178 int Size() const
00179 { return static_cast<int>( m_Candidates.size() ); }
00180
00183 void SetCentroids(InternalParametersType& centroids)
00184 {
00185 this->m_MeasurementVectorSize = MeasurementVectorTraits::GetLength( centroids[0] );
00186 m_Candidates.resize(centroids.size());
00187 for (unsigned int i = 0; i < centroids.size(); i++)
00188 {
00189 Candidate candidate;
00190 candidate.Centroid = centroids[i];
00191 MeasurementVectorTraits::SetLength( candidate.WeightedCentroid, m_MeasurementVectorSize );
00192 candidate.WeightedCentroid.Fill(0.0);
00193 candidate.Size = 0;
00194 m_Candidates[i] = candidate;
00195 }
00196 }
00198
00200 void GetCentroids(InternalParametersType& centroids)
00201 {
00202 unsigned int i;
00203 centroids.resize(this->Size());
00204 for (i = 0; i < (unsigned int)this->Size(); i++)
00205 {
00206 centroids[i] = m_Candidates[i].Centroid;
00207 }
00208 }
00210
00213 void UpdateCentroids()
00214 {
00215 unsigned int i, j;
00216 for (i = 0; i < (unsigned int)this->Size(); i++)
00217 {
00218 if (m_Candidates[i].Size > 0)
00219 {
00220 for (j = 0; j < m_MeasurementVectorSize; j++)
00221 {
00222 m_Candidates[i].Centroid[j] =
00223 m_Candidates[i].WeightedCentroid[j] /
00224 double(m_Candidates[i].Size);
00225 }
00226 }
00227 }
00228 }
00230
00232 Candidate& operator[](int index)
00233 { return m_Candidates[index]; }
00234
00235 private:
00237 std::vector< Candidate > m_Candidates;
00238
00240 MeasurementVectorSizeType m_MeasurementVectorSize;
00241 };
00242
00248 double GetSumOfSquaredPositionChanges(InternalParametersType &previous,
00249 InternalParametersType ¤t);
00250
00253 int GetClosestCandidate(ParameterType &measurements,
00254 std::vector< int > &validIndexes);
00255
00257 bool IsFarther(ParameterType &pointA,
00258 ParameterType &pointB,
00259 MeasurementVectorType &lowerBound,
00260 MeasurementVectorType &upperBound);
00261
00264 void Filter(KdTreeNodeType* node,
00265 std::vector< int > validIndexes,
00266 MeasurementVectorType &lowerBound,
00267 MeasurementVectorType &upperBound);
00268
00270 void CopyParameters(InternalParametersType &source, InternalParametersType &target);
00271
00273 void CopyParameters(ParametersType &source, InternalParametersType &target);
00274
00276 void CopyParameters(InternalParametersType &source, ParametersType &target);
00277
00279 void GetPoint(ParameterType &point, MeasurementVectorType measurements)
00280 {
00281 for (unsigned int i = 0; i < m_MeasurementVectorSize; i++)
00282 {
00283 point[i] = measurements[i];
00284 }
00285 }
00287
00288 void PrintPoint(ParameterType &point)
00289 {
00290 std::cout << "[ ";
00291 for (unsigned int i = 0; i < m_MeasurementVectorSize; i++)
00292 {
00293 std::cout << point[i] << " ";
00294 }
00295 std::cout << "]";
00296 }
00297
00298 private:
00300 int m_CurrentIteration;
00301
00303 int m_MaximumIteration;
00304
00306 double m_CentroidPositionChanges;
00307
00310 double m_CentroidPositionChangesThreshold;
00311
00313 typename TKdTree::Pointer m_KdTree;
00314
00316 typename EuclideanDistance< ParameterType >::Pointer m_DistanceMetric;
00317
00319 ParametersType m_Parameters;
00320
00321 CandidateVector m_CandidateVector;
00322
00323 ParameterType m_TempVertex;
00324
00325 bool m_UseClusterLabels;
00326 bool m_GenerateClusterLabels;
00327 ClusterLabelsType m_ClusterLabels;
00328 MeasurementVectorSizeType m_MeasurementVectorSize;
00329 };
00330
00331 }
00332 }
00333
00334 #ifndef ITK_MANUAL_INSTANTIATION
00335 #include "itkKdTreeBasedKmeansEstimator.txx"
00336 #endif
00337
00338
00339 #endif
00340