// ITK 4.1.0 -- Insight Segmentation and Registration Toolkit
00001 /*========================================================================= 00002 * 00003 * Copyright Insight Software Consortium 00004 * 00005 * Licensed under the Apache License, Version 2.0 (the "License"); 00006 * you may not use this file except in compliance with the License. 00007 * You may obtain a copy of the License at 00008 * 00009 * http://www.apache.org/licenses/LICENSE-2.0.txt 00010 * 00011 * Unless required by applicable law or agreed to in writing, software 00012 * distributed under the License is distributed on an "AS IS" BASIS, 00013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00014 * See the License for the specific language governing permissions and 00015 * limitations under the License. 00016 * 00017 *=========================================================================*/ 00018 #ifndef __itkKdTreeBasedKmeansEstimator_h 00019 #define __itkKdTreeBasedKmeansEstimator_h 00020 00021 #include <vector> 00022 #include "itksys/hash_map.hxx" 00023 00024 #include "itkObject.h" 00025 #include "itkEuclideanDistanceMetric.h" 00026 #include "itkDistanceToCentroidMembershipFunction.h" 00027 #include "itkSimpleDataObjectDecorator.h" 00028 #include "itkNumericTraitsArrayPixel.h" 00029 00030 namespace itk 00031 { 00032 namespace Statistics 00033 { 00076 template< class TKdTree > 00077 class ITK_EXPORT KdTreeBasedKmeansEstimator: 00078 public Object 00079 { 00080 public: 00082 typedef KdTreeBasedKmeansEstimator Self; 00083 typedef Object Superclass; 00084 typedef SmartPointer< Self > Pointer; 00085 typedef SmartPointer< const Self > ConstPointer; 00086 00088 itkNewMacro(Self); 00089 00091 itkTypeMacro(KdTreeBasedKmeansEstimator, Object); 00092 00094 typedef typename TKdTree::KdTreeNodeType KdTreeNodeType; 00095 typedef typename TKdTree::MeasurementType MeasurementType; 00096 typedef typename TKdTree::MeasurementVectorType MeasurementVectorType; 00097 typedef typename TKdTree::InstanceIdentifier InstanceIdentifier; 00098 
typedef typename TKdTree::SampleType SampleType; 00099 typedef typename KdTreeNodeType::CentroidType CentroidType; 00100 00102 typedef unsigned int MeasurementVectorSizeType; 00103 00106 typedef Array< double > ParameterType; 00107 typedef std::vector< ParameterType > InternalParametersType; 00108 typedef Array< double > ParametersType; 00109 00112 typedef DistanceToCentroidMembershipFunction< MeasurementVectorType > 00113 DistanceToCentroidMembershipFunctionType; 00114 00115 typedef typename DistanceToCentroidMembershipFunctionType::Pointer 00116 DistanceToCentroidMembershipFunctionPointer; 00117 00118 typedef MembershipFunctionBase< MeasurementVectorType > MembershipFunctionType; 00119 typedef typename MembershipFunctionType::ConstPointer MembershipFunctionPointer; 00120 typedef std::vector< MembershipFunctionPointer > MembershipFunctionVectorType; 00121 typedef SimpleDataObjectDecorator< 00122 MembershipFunctionVectorType > MembershipFunctionVectorObjectType; 00123 typedef typename 00124 MembershipFunctionVectorObjectType::Pointer MembershipFunctionVectorObjectPointer; 00125 00128 const MembershipFunctionVectorObjectType * GetOutput() const; 00129 00131 itkSetMacro(Parameters, ParametersType); 00132 itkGetConstMacro(Parameters, ParametersType); 00134 00136 itkSetMacro(MaximumIteration, int); 00137 itkGetConstMacro(MaximumIteration, int); 00139 00142 itkSetMacro(CentroidPositionChangesThreshold, double); 00143 itkGetConstMacro(CentroidPositionChangesThreshold, double); 00144 00146 void SetKdTree(TKdTree *tree); 00147 00148 const TKdTree * GetKdTree() const; 00149 00151 itkGetConstMacro(MeasurementVectorSize, MeasurementVectorSizeType); 00152 00153 itkGetConstMacro(CurrentIteration, int); 00154 itkGetConstMacro(CentroidPositionChanges, double); 00155 00160 void StartOptimization(); 00161 00162 typedef itksys::hash_map< InstanceIdentifier, unsigned int > ClusterLabelsType; 00163 00164 itkSetMacro(UseClusterLabels, bool); 00165 itkGetConstMacro(UseClusterLabels, 
bool); 00166 protected: 00167 KdTreeBasedKmeansEstimator(); 00168 virtual ~KdTreeBasedKmeansEstimator() {} 00169 00170 void PrintSelf(std::ostream & os, Indent indent) const; 00171 00172 void FillClusterLabels(KdTreeNodeType *node, int closestIndex); 00173 00178 class CandidateVector 00179 { 00180 public: 00181 CandidateVector() {} 00182 00183 struct Candidate { 00184 CentroidType Centroid; 00185 CentroidType WeightedCentroid; 00186 int Size; 00187 }; // end of struct 00188 00189 virtual ~CandidateVector() {} 00190 00192 int Size() const 00193 { 00194 return static_cast< int >( m_Candidates.size() ); 00195 } 00196 00199 void SetCentroids(InternalParametersType & centroids) 00200 { 00201 this->m_MeasurementVectorSize = NumericTraits<ParameterType>::GetLength(centroids[0]); 00202 m_Candidates.resize( centroids.size() ); 00203 for ( unsigned int i = 0; i < centroids.size(); i++ ) 00204 { 00205 Candidate candidate; 00206 candidate.Centroid = centroids[i]; 00207 NumericTraits<CentroidType>::SetLength(candidate.WeightedCentroid, 00208 m_MeasurementVectorSize); 00209 candidate.WeightedCentroid.Fill(0.0); 00210 candidate.Size = 0; 00211 m_Candidates[i] = candidate; 00212 } 00213 } 00215 00217 void GetCentroids(InternalParametersType & centroids) 00218 { 00219 unsigned int i; 00220 00221 centroids.resize( this->Size() ); 00222 for ( i = 0; i < (unsigned int)this->Size(); i++ ) 00223 { 00224 centroids[i] = m_Candidates[i].Centroid; 00225 } 00226 } 00227 00230 void UpdateCentroids() 00231 { 00232 unsigned int i, j; 00233 00234 for ( i = 0; i < (unsigned int)this->Size(); i++ ) 00235 { 00236 if ( m_Candidates[i].Size > 0 ) 00237 { 00238 for ( j = 0; j < m_MeasurementVectorSize; j++ ) 00239 { 00240 m_Candidates[i].Centroid[j] = 00241 m_Candidates[i].WeightedCentroid[j] 00242 / double(m_Candidates[i].Size); 00243 } 00244 } 00245 } 00246 } 00247 00249 Candidate & operator[](int index) 00250 { 00251 return m_Candidates[index]; 00252 } 00253 00254 private: 00256 std::vector< 
Candidate > m_Candidates; 00257 00259 MeasurementVectorSizeType m_MeasurementVectorSize; 00260 }; // end of class 00261 00267 double GetSumOfSquaredPositionChanges(InternalParametersType & previous, 00268 InternalParametersType & current); 00269 00272 int GetClosestCandidate(ParameterType & measurements, 00273 std::vector< int > & validIndexes); 00274 00276 bool IsFarther(ParameterType & pointA, 00277 ParameterType & pointB, 00278 MeasurementVectorType & lowerBound, 00279 MeasurementVectorType & upperBound); 00280 00283 void Filter(KdTreeNodeType *node, 00284 std::vector< int > validIndexes, 00285 MeasurementVectorType & lowerBound, 00286 MeasurementVectorType & upperBound); 00287 00289 void CopyParameters(InternalParametersType & source, InternalParametersType & target); 00290 00292 void CopyParameters(ParametersType & source, InternalParametersType & target); 00293 00295 void CopyParameters(InternalParametersType & source, ParametersType & target); 00296 00298 void GetPoint(ParameterType & point, MeasurementVectorType measurements); 00299 00300 void PrintPoint(ParameterType & point); 00301 00302 private: 00304 int m_CurrentIteration; 00305 00307 int m_MaximumIteration; 00308 00310 double m_CentroidPositionChanges; 00311 00314 double m_CentroidPositionChangesThreshold; 00315 00317 typename TKdTree::Pointer m_KdTree; 00318 00320 typename EuclideanDistanceMetric< ParameterType >::Pointer m_DistanceMetric; 00321 00323 ParametersType m_Parameters; 00324 00325 CandidateVector m_CandidateVector; 00326 00327 ParameterType m_TempVertex; 00328 00329 bool m_UseClusterLabels; 00330 bool m_GenerateClusterLabels; 00331 ClusterLabelsType m_ClusterLabels; 00332 MeasurementVectorSizeType m_MeasurementVectorSize; 00333 MembershipFunctionVectorObjectPointer m_MembershipFunctionsObject; 00334 }; // end of class 00335 } // end of namespace Statistics 00336 } // end of namespace itk 00337 00338 #ifndef ITK_MANUAL_INSTANTIATION 00339 #include "itkKdTreeBasedKmeansEstimator.hxx" 00340 
#endif 00341 00342 #endif 00343