00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkKdTreeBasedKmeansEstimator_h
00018 #define __itkKdTreeBasedKmeansEstimator_h
00019
00020 #include <vector>
00021 #include "itk_hash_map.h"
00022
00023 #include "itkObject.h"
00024
00025 namespace itk {
00026 namespace Statistics {
00027
00054 template< class TKdTree >
00055 class ITK_EXPORT KdTreeBasedKmeansEstimator:
00056 public Object
00057 {
00058 public:
00060 typedef KdTreeBasedKmeansEstimator Self ;
00061 typedef Object Superclass;
00062 typedef SmartPointer<Self> Pointer;
00063 typedef SmartPointer<const Self> ConstPointer;
00064
00066 itkNewMacro(Self);
00067
00069 itkTypeMacro(KdTreeBasedKmeansEstimator, Obeject);
00070
00072 typedef typename TKdTree::KdTreeNodeType KdTreeNodeType ;
00073 typedef typename TKdTree::MeasurementType MeasurementType ;
00074 typedef typename TKdTree::MeasurementVectorType MeasurementVectorType ;
00075 typedef typename TKdTree::InstanceIdentifier InstanceIdentifier ;
00076 typedef typename TKdTree::SampleType SampleType ;
00077 typedef typename KdTreeNodeType::CenteroidType CenteroidType ;
00078 itkStaticConstMacro(MeasurementVectorSize, unsigned int,
00079 TKdTree::MeasurementVectorSize);
00082
00083 typedef FixedArray< double, itkGetStaticConstMacro(MeasurementVectorSize) > ParameterType ;
00084 typedef std::vector< ParameterType > InternalParametersType;
00085 typedef Array< double > ParametersType;
00086
00088 void SetParameters(ParametersType& params)
00089 { m_Parameters = params ; }
00090
00092 ParametersType& GetParameters()
00093 { return m_Parameters ; }
00094
00096 itkSetMacro( MaximumIteration, int );
00097 itkGetConstMacro( MaximumIteration, int );
00098
00101 itkSetMacro( CenteroidPositionChangesThreshold, double );
00102 itkGetConstMacro( CenteroidPositionChangesThreshold, double );
00103
00105 void SetKdTree(TKdTree* tree)
00106 { m_KdTree = tree ; }
00107
00108 TKdTree* GetKdTree()
00109 { return m_KdTree ; }
00110
00111 itkGetConstMacro( CurrentIteration, int) ;
00112 itkGetConstMacro( CenteroidPositionChanges, double) ;
00113
00118 void StartOptimization() ;
00119
00120 typedef itk::hash_map< InstanceIdentifier, unsigned int > ClusterLabelsType ;
00121
00122 void SetUseClusterLabels(bool flag)
00123 { m_UseClusterLabels = flag ; }
00124
00125 ClusterLabelsType* GetClusterLabels()
00126 { return &m_ClusterLabels ; }
00127
00128 protected:
00129 KdTreeBasedKmeansEstimator() ;
00130 virtual ~KdTreeBasedKmeansEstimator() {}
00131
00132 void PrintSelf(std::ostream& os, Indent indent) const;
00133
00134 void FillClusterLabels(KdTreeNodeType* node, int closestIndex) ;
00135
00137 class CandidateVector
00138 {
00139 public:
00140 CandidateVector() {}
00141
00142 struct Candidate
00143 {
00144 CenteroidType Centeroid ;
00145 CenteroidType WeightedCenteroid ;
00146 int Size ;
00147 } ;
00148
00149 virtual ~CandidateVector() {}
00150
00152 int Size()
00153 { return m_Candidates.size() ; }
00154
00157 void SetCenteroids(InternalParametersType& centeroids)
00158 {
00159 m_Candidates.resize(centeroids.size()) ;
00160 for (unsigned int i = 0 ; i < centeroids.size() ; i++)
00161 {
00162 Candidate candidate ;
00163 candidate.Centeroid = centeroids[i] ;
00164 candidate.WeightedCenteroid.Fill(0.0) ;
00165 candidate.Size = 0 ;
00166 m_Candidates[i] = candidate ;
00167 }
00168 }
00169
00171 void GetCenteroids(InternalParametersType& centeroids)
00172 {
00173 unsigned int i ;
00174 centeroids.resize(this->Size()) ;
00175 for (i = 0 ; i < (unsigned int)this->Size() ; i++)
00176 {
00177 centeroids[i] = m_Candidates[i].Centeroid ;
00178 }
00179 }
00180
00183 void UpdateCenteroids()
00184 {
00185 unsigned int i, j ;
00186 for (i = 0 ; i < (unsigned int)this->Size() ; i++)
00187 {
00188 if (m_Candidates[i].Size > 0)
00189 {
00190 for (j = 0 ; j < MeasurementVectorSize ; j++)
00191 {
00192 m_Candidates[i].Centeroid[j] =
00193 m_Candidates[i].WeightedCenteroid[j] /
00194 double(m_Candidates[i].Size) ;
00195 }
00196 }
00197 }
00198 }
00199
00201 Candidate& operator[](int index)
00202 { return m_Candidates[index] ; }
00203
00204
00205 private:
00207 std::vector< Candidate > m_Candidates ;
00208 } ;
00209
00215 double GetSumOfSquaredPositionChanges(InternalParametersType &previous,
00216 InternalParametersType ¤t) ;
00217
00220 int GetClosestCandidate(ParameterType &measurements,
00221 std::vector< int > &validIndexes) ;
00222
00224 bool IsFarther(ParameterType &pointA,
00225 ParameterType &pointB,
00226 MeasurementVectorType &lowerBound,
00227 MeasurementVectorType &upperBound) ;
00228
00231 void Filter(KdTreeNodeType* node,
00232 std::vector< int > validIndexes,
00233 MeasurementVectorType &lowerBound,
00234 MeasurementVectorType &upperBound) ;
00235
00237 void CopyParameters(InternalParametersType &source, InternalParametersType &target) ;
00238
00240 void CopyParameters(ParametersType &source, InternalParametersType &target) ;
00241
00243 void CopyParameters(InternalParametersType &source, ParametersType &target) ;
00244
00246 void GetPoint(ParameterType &point,
00247 MeasurementVectorType &measurements)
00248 {
00249 for (unsigned int i = 0 ; i < MeasurementVectorSize ; i++)
00250 {
00251 point[i] = measurements[i] ;
00252 }
00253 }
00254
00255 void PrintPoint(ParameterType &point)
00256 {
00257 std::cout << "[ " ;
00258 for (unsigned int i = 0 ; i < MeasurementVectorSize ; i++)
00259 {
00260 std::cout << point[i] << " " ;
00261 }
00262 std::cout << "]" ;
00263 }
00264
00265 private:
00267 int m_CurrentIteration ;
00269 int m_MaximumIteration ;
00271 double m_CenteroidPositionChanges ;
00274 double m_CenteroidPositionChangesThreshold ;
00276 TKdTree* m_KdTree ;
00278 typename EuclideanDistance< ParameterType >::Pointer m_DistanceMetric ;
00279
00281 ParametersType m_Parameters ;
00282
00283 CandidateVector m_CandidateVector ;
00284
00285 ParameterType m_TempVertex ;
00286
00287 bool m_UseClusterLabels ;
00288 bool m_GenerateClusterLabels ;
00289 ClusterLabelsType m_ClusterLabels ;
00290 } ;
00291
00292 }
00293 }
00294
00295 #ifndef ITK_MANUAL_INSTANTIATION
00296 #include "itkKdTreeBasedKmeansEstimator.txx"
00297 #endif
00298
00299
00300 #endif