00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkKdTreeBasedKmeansEstimator_h
00018 #define __itkKdTreeBasedKmeansEstimator_h
00019
00020 #include <vector>
00021 #include "itk_hash_map.h"
00022
00023 #include "itkObject.h"
00024 #include "itkMeasurementVectorTraits.h"
00025 #include "itkEuclideanDistanceMetric.h"
00026 #include "itkDistanceToCentroidMembershipFunction.h"
00027 #include "itkSimpleDataObjectDecorator.h"
00028
00029 namespace itk {
00030 namespace Statistics {
00031
00069 template< class TKdTree >
00070 class ITK_EXPORT KdTreeBasedKmeansEstimator:
00071 public Object
00072 {
00073 public:
00075 typedef KdTreeBasedKmeansEstimator Self;
00076 typedef Object Superclass;
00077 typedef SmartPointer<Self> Pointer;
00078 typedef SmartPointer<const Self> ConstPointer;
00079
00081 itkNewMacro(Self);
00082
00084 itkTypeMacro(KdTreeBasedKmeansEstimator, Obeject);
00085
00087 typedef typename TKdTree::KdTreeNodeType KdTreeNodeType;
00088 typedef typename TKdTree::MeasurementType MeasurementType;
00089 typedef typename TKdTree::MeasurementVectorType MeasurementVectorType;
00090 typedef typename TKdTree::InstanceIdentifier InstanceIdentifier;
00091 typedef typename TKdTree::SampleType SampleType;
00092 typedef typename KdTreeNodeType::CentroidType CentroidType;
00093
00094
00096 typedef unsigned int MeasurementVectorSizeType;
00097
00100 typedef Array< double > ParameterType;
00101 typedef std::vector< ParameterType > InternalParametersType;
00102 typedef Array< double > ParametersType;
00103
00104
00107 typedef DistanceToCentroidMembershipFunction< MeasurementVectorType >
00108 DistanceToCentroidMembershipFunctionType;
00109
00110 typedef typename DistanceToCentroidMembershipFunctionType::Pointer
00111 DistanceToCentroidMembershipFunctionPointer;
00112
00113 typedef MembershipFunctionBase< MeasurementVectorType > MembershipFunctionType;
00114 typedef typename MembershipFunctionType::ConstPointer MembershipFunctionPointer;
00115 typedef std::vector< MembershipFunctionPointer > MembershipFunctionVectorType;
00116 typedef SimpleDataObjectDecorator<
00117 MembershipFunctionVectorType > MembershipFunctionVectorObjectType;
00118 typedef typename
00119 MembershipFunctionVectorObjectType::Pointer MembershipFunctionVectorObjectPointer;
00120
00123 const MembershipFunctionVectorObjectType * GetOutput() const;
00124
00126 itkSetMacro( Parameters, ParametersType );
00127 itkGetConstMacro( Parameters, ParametersType );
00129
00131 itkSetMacro( MaximumIteration, int );
00132 itkGetConstMacro( MaximumIteration, int );
00134
00137 itkSetMacro( CentroidPositionChangesThreshold, double );
00138 itkGetConstMacro( CentroidPositionChangesThreshold, double );
00139
00141 void SetKdTree(TKdTree* tree);
00142 const TKdTree* GetKdTree() const;
00144
00146 itkGetConstMacro( MeasurementVectorSize, MeasurementVectorSizeType );
00147
00148 itkGetConstMacro( CurrentIteration, int);
00149 itkGetConstMacro( CentroidPositionChanges, double);
00150
00155 void StartOptimization();
00156
00157 typedef itk::hash_map< InstanceIdentifier, unsigned int > ClusterLabelsType;
00158
00159 itkSetMacro( UseClusterLabels, bool );
00160 itkGetConstMacro( UseClusterLabels, bool );
00161
00162 protected:
00163 KdTreeBasedKmeansEstimator();
00164 virtual ~KdTreeBasedKmeansEstimator() {}
00165
00166 void PrintSelf(std::ostream& os, Indent indent) const;
00167
00168 void FillClusterLabels(KdTreeNodeType* node, int closestIndex);
00169
00171 class CandidateVector {
00172 public:
00173 CandidateVector() {}
00174
00175 struct Candidate {
00176 CentroidType Centroid;
00177 CentroidType WeightedCentroid;
00178 int Size;
00179 };
00180
00181 virtual ~CandidateVector() {}
00182
00184 int Size() const
00185 {
00186 return static_cast<int>( m_Candidates.size() );
00187 }
00188
00191 void SetCentroids(InternalParametersType& centroids)
00192 {
00193 this->m_MeasurementVectorSize = MeasurementVectorTraits::GetLength( centroids[0] );
00194 m_Candidates.resize(centroids.size());
00195 for (unsigned int i = 0; i < centroids.size(); i++)
00196 {
00197 Candidate candidate;
00198 candidate.Centroid = centroids[i];
00199 MeasurementVectorTraits::SetLength( candidate.WeightedCentroid, m_MeasurementVectorSize );
00200 candidate.WeightedCentroid.Fill(0.0);
00201 candidate.Size = 0;
00202 m_Candidates[i] = candidate;
00203 }
00204 }
00206
00208 void GetCentroids(InternalParametersType& centroids)
00209 {
00210 unsigned int i;
00211 centroids.resize(this->Size());
00212 for (i = 0; i < (unsigned int)this->Size(); i++)
00213 {
00214 centroids[i] = m_Candidates[i].Centroid;
00215 }
00216 }
00218
00221 void UpdateCentroids()
00222 {
00223 unsigned int i, j;
00224 for (i = 0; i < (unsigned int)this->Size(); i++)
00225 {
00226 if (m_Candidates[i].Size > 0)
00227 {
00228 for (j = 0; j < m_MeasurementVectorSize; j++)
00229 {
00230 m_Candidates[i].Centroid[j] =
00231 m_Candidates[i].WeightedCentroid[j] /
00232 double(m_Candidates[i].Size);
00233 }
00234 }
00235 }
00236 }
00238
00240 Candidate& operator[](int index)
00241 {
00242 return m_Candidates[index];
00243 }
00244
00245 private:
00247 std::vector< Candidate > m_Candidates;
00248
00250 MeasurementVectorSizeType m_MeasurementVectorSize;
00251 };
00252
00258 double GetSumOfSquaredPositionChanges(InternalParametersType &previous,
00259 InternalParametersType ¤t);
00260
00263 int GetClosestCandidate(ParameterType &measurements,
00264 std::vector< int > &validIndexes);
00265
00267 bool IsFarther(ParameterType &pointA,
00268 ParameterType &pointB,
00269 MeasurementVectorType &lowerBound,
00270 MeasurementVectorType &upperBound);
00271
00274 void Filter(KdTreeNodeType* node,
00275 std::vector< int > validIndexes,
00276 MeasurementVectorType &lowerBound,
00277 MeasurementVectorType &upperBound);
00278
00280 void CopyParameters(InternalParametersType &source, InternalParametersType &target);
00281
00283 void CopyParameters(ParametersType &source, InternalParametersType &target);
00284
00286 void CopyParameters(InternalParametersType &source, ParametersType &target);
00287
00289 void GetPoint(ParameterType &point, MeasurementVectorType measurements);
00290
00291 void PrintPoint(ParameterType &point);
00292
00293 private:
00295 int m_CurrentIteration;
00296
00298 int m_MaximumIteration;
00299
00301 double m_CentroidPositionChanges;
00302
00305 double m_CentroidPositionChangesThreshold;
00306
00308 typename TKdTree::Pointer m_KdTree;
00309
00311 typename EuclideanDistanceMetric< ParameterType >::Pointer m_DistanceMetric;
00312
00314 ParametersType m_Parameters;
00315
00316 CandidateVector m_CandidateVector;
00317
00318 ParameterType m_TempVertex;
00319
00320 bool m_UseClusterLabels;
00321 bool m_GenerateClusterLabels;
00322 ClusterLabelsType m_ClusterLabels;
00323 MeasurementVectorSizeType m_MeasurementVectorSize;
00324 MembershipFunctionVectorObjectPointer m_MembershipFunctionsObject;
00325
00326 };
00327
00328 }
00329
00330 }
00331
00332 #ifndef ITK_MANUAL_INSTANTIATION
00333 #include "itkKdTreeBasedKmeansEstimator.txx"
00334 #endif
00335
00336 #endif
00337