ITK 4.0.0
Insight Segmentation and Registration Toolkit
itkKdTreeBasedKmeansEstimator.h
Go to the documentation of this file.
00001 /*=========================================================================
00002  *
00003  *  Copyright Insight Software Consortium
00004  *
00005  *  Licensed under the Apache License, Version 2.0 (the "License");
00006  *  you may not use this file except in compliance with the License.
00007  *  You may obtain a copy of the License at
00008  *
00009  *         http://www.apache.org/licenses/LICENSE-2.0.txt
00010  *
00011  *  Unless required by applicable law or agreed to in writing, software
00012  *  distributed under the License is distributed on an "AS IS" BASIS,
00013  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00014  *  See the License for the specific language governing permissions and
00015  *  limitations under the License.
00016  *
00017  *=========================================================================*/
00018 #ifndef __itkKdTreeBasedKmeansEstimator_h
00019 #define __itkKdTreeBasedKmeansEstimator_h
00020 
00021 #include <vector>
00022 #include "itksys/hash_map.hxx"
00023 
00024 #include "itkObject.h"
00025 #include "itkEuclideanDistanceMetric.h"
00026 #include "itkDistanceToCentroidMembershipFunction.h"
00027 #include "itkSimpleDataObjectDecorator.h"
00028 #include "itkNumericTraitsArrayPixel.h"
00029 
00030 namespace itk
00031 {
00032 namespace Statistics
00033 {
00076 template< class TKdTree >
00077 class ITK_EXPORT KdTreeBasedKmeansEstimator:
00078   public Object
00079 {
00080 public:
00082   typedef KdTreeBasedKmeansEstimator Self;
00083   typedef Object                     Superclass;
00084   typedef SmartPointer< Self >       Pointer;
00085   typedef SmartPointer< const Self > ConstPointer;
00086 
00088   itkNewMacro(Self);
00089 
00091   itkTypeMacro(KdTreeBasedKmeansEstimator, Object);
00092 
00094   typedef typename TKdTree::KdTreeNodeType        KdTreeNodeType;
00095   typedef typename TKdTree::MeasurementType       MeasurementType;
00096   typedef typename TKdTree::MeasurementVectorType MeasurementVectorType;
00097   typedef typename TKdTree::InstanceIdentifier    InstanceIdentifier;
00098   typedef typename TKdTree::SampleType            SampleType;
00099   typedef typename KdTreeNodeType::CentroidType   CentroidType;
00100 
00102   typedef unsigned int MeasurementVectorSizeType;
00103 
00106   typedef Array< double >              ParameterType;
00107   typedef std::vector< ParameterType > InternalParametersType;
00108   typedef Array< double >              ParametersType;
00109 
00112   typedef DistanceToCentroidMembershipFunction< MeasurementVectorType >
00113   DistanceToCentroidMembershipFunctionType;
00114 
00115   typedef typename DistanceToCentroidMembershipFunctionType::Pointer
00116   DistanceToCentroidMembershipFunctionPointer;
00117 
00118   typedef MembershipFunctionBase< MeasurementVectorType > MembershipFunctionType;
00119   typedef typename MembershipFunctionType::ConstPointer   MembershipFunctionPointer;
00120   typedef std::vector< MembershipFunctionPointer >        MembershipFunctionVectorType;
00121   typedef SimpleDataObjectDecorator<
00122     MembershipFunctionVectorType >                        MembershipFunctionVectorObjectType;
00123   typedef typename
00124   MembershipFunctionVectorObjectType::Pointer MembershipFunctionVectorObjectPointer;
00125 
00128   const MembershipFunctionVectorObjectType * GetOutput() const;
00129 
00131   itkSetMacro(Parameters, ParametersType);
00132   itkGetConstMacro(Parameters, ParametersType);
00134 
00136   itkSetMacro(MaximumIteration, int);
00137   itkGetConstMacro(MaximumIteration, int);
00139 
00142   itkSetMacro(CentroidPositionChangesThreshold, double);
00143   itkGetConstMacro(CentroidPositionChangesThreshold, double);
00144 
00146   void SetKdTree(TKdTree *tree);
00147 
00148   const TKdTree * GetKdTree() const;
00149 
00151   itkGetConstMacro(MeasurementVectorSize, MeasurementVectorSizeType);
00152 
00153   itkGetConstMacro(CurrentIteration, int);
00154   itkGetConstMacro(CentroidPositionChanges, double);
00155 
00160   void StartOptimization();
00161 
00162   typedef itksys::hash_map< InstanceIdentifier, unsigned int > ClusterLabelsType;
00163 
00164   itkSetMacro(UseClusterLabels, bool);
00165   itkGetConstMacro(UseClusterLabels, bool);
00166 protected:
00167   KdTreeBasedKmeansEstimator();
00168   virtual ~KdTreeBasedKmeansEstimator() {}
00169 
00170   void PrintSelf(std::ostream & os, Indent indent) const;
00171 
00172   void FillClusterLabels(KdTreeNodeType *node, int closestIndex);
00173 
00178   class CandidateVector
00179   {
00180 public:
00181     CandidateVector() {}
00182 
00183     struct Candidate {
00184       CentroidType Centroid;
00185       CentroidType WeightedCentroid;
00186       int Size;
00187     };   // end of struct
00188 
00189     virtual ~CandidateVector() {}
00190 
00192     int Size() const
00193     {
00194       return static_cast< int >( m_Candidates.size() );
00195     }
00196 
00199     void SetCentroids(InternalParametersType & centroids)
00200     {
00201       this->m_MeasurementVectorSize = NumericTraits<ParameterType>::GetLength(centroids[0]);
00202       m_Candidates.resize( centroids.size() );
00203       for ( unsigned int i = 0; i < centroids.size(); i++ )
00204         {
00205         Candidate candidate;
00206         candidate.Centroid = centroids[i];
00207         NumericTraits<CentroidType>::SetLength(candidate.WeightedCentroid,
00208           m_MeasurementVectorSize);
00209         candidate.WeightedCentroid.Fill(0.0);
00210         candidate.Size = 0;
00211         m_Candidates[i] = candidate;
00212         }
00213     }
00215 
00217     void GetCentroids(InternalParametersType & centroids)
00218     {
00219       unsigned int i;
00220 
00221       centroids.resize( this->Size() );
00222       for ( i = 0; i < (unsigned int)this->Size(); i++ )
00223         {
00224         centroids[i] = m_Candidates[i].Centroid;
00225         }
00226     }
00227 
00230     void UpdateCentroids()
00231     {
00232       unsigned int i, j;
00233 
00234       for ( i = 0; i < (unsigned int)this->Size(); i++ )
00235         {
00236         if ( m_Candidates[i].Size > 0 )
00237           {
00238           for ( j = 0; j < m_MeasurementVectorSize; j++ )
00239             {
00240             m_Candidates[i].Centroid[j] =
00241               m_Candidates[i].WeightedCentroid[j]
00242               / double(m_Candidates[i].Size);
00243             }
00244           }
00245         }
00246     }
00247 
00249     Candidate & operator[](int index)
00250     {
00251       return m_Candidates[index];
00252     }
00253 
00254 private:
00256     std::vector< Candidate > m_Candidates;
00257 
00259     MeasurementVectorSizeType m_MeasurementVectorSize;
00260   };  // end of class
00261 
00267   double GetSumOfSquaredPositionChanges(InternalParametersType & previous,
00268                                         InternalParametersType & current);
00269 
00272   int GetClosestCandidate(ParameterType & measurements,
00273                           std::vector< int > & validIndexes);
00274 
00276   bool IsFarther(ParameterType & pointA,
00277                  ParameterType & pointB,
00278                  MeasurementVectorType & lowerBound,
00279                  MeasurementVectorType & upperBound);
00280 
00283   void Filter(KdTreeNodeType *node,
00284               std::vector< int > validIndexes,
00285               MeasurementVectorType & lowerBound,
00286               MeasurementVectorType & upperBound);
00287 
00289   void CopyParameters(InternalParametersType & source, InternalParametersType & target);
00290 
00292   void CopyParameters(ParametersType & source, InternalParametersType & target);
00293 
00295   void CopyParameters(InternalParametersType & source, ParametersType & target);
00296 
00298   void GetPoint(ParameterType & point, MeasurementVectorType measurements);
00299 
00300   void PrintPoint(ParameterType & point);
00301 
00302 private:
00304   int m_CurrentIteration;
00305 
00307   int m_MaximumIteration;
00308 
00310   double m_CentroidPositionChanges;
00311 
00314   double m_CentroidPositionChangesThreshold;
00315 
00317   typename TKdTree::Pointer m_KdTree;
00318 
00320   typename EuclideanDistanceMetric< ParameterType >::Pointer m_DistanceMetric;
00321 
00323   ParametersType m_Parameters;
00324 
00325   CandidateVector m_CandidateVector;
00326 
00327   ParameterType m_TempVertex;
00328 
00329   bool                                  m_UseClusterLabels;
00330   bool                                  m_GenerateClusterLabels;
00331   ClusterLabelsType                     m_ClusterLabels;
00332   MeasurementVectorSizeType             m_MeasurementVectorSize;
00333   MembershipFunctionVectorObjectPointer m_MembershipFunctionsObject;
00334 };  // end of class
00335 } // end of namespace Statistics
00336 } // end of namespace itk
00337 
00338 #ifndef ITK_MANUAL_INSTANTIATION
00339 #include "itkKdTreeBasedKmeansEstimator.hxx"
00340 #endif
00341 
00342 #endif
00343