ITK  5.0.0
Insight Segmentation and Registration Toolkit
itkKdTreeBasedKmeansEstimator.h
Go to the documentation of this file.
1 /*=========================================================================
2  *
3  * Copyright Insight Software Consortium
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  *=========================================================================*/
18 #ifndef itkKdTreeBasedKmeansEstimator_h
19 #define itkKdTreeBasedKmeansEstimator_h
20 
21 #include <vector>
22 #include "itksys/hash_map.hxx"
23 
24 #include "itkObject.h"
29 
30 namespace itk
31 {
32 namespace Statistics
33 {
76 template< typename TKdTree >
77 class ITK_TEMPLATE_EXPORT KdTreeBasedKmeansEstimator:
78  public Object
79 {
80 public:
83  using Superclass = Object;
86 
88  itkNewMacro(Self);
89 
91  itkTypeMacro(KdTreeBasedKmeansEstimator, Object);
92 
/** Types re-exported from the k-d tree template argument. Each alias names
 * a nested type of TKdTree, except CentroidType, which is taken from the
 * tree's node type (KdTreeNodeType) rather than from the tree itself. */
94  using KdTreeNodeType = typename TKdTree::KdTreeNodeType;
95  using MeasurementType = typename TKdTree::MeasurementType;
96  using MeasurementVectorType = typename TKdTree::MeasurementVectorType;
97  using InstanceIdentifier = typename TKdTree::InstanceIdentifier;
98  using SampleType = typename TKdTree::SampleType;
99  using CentroidType = typename KdTreeNodeType::CentroidType;
100 
102  using MeasurementVectorSizeType = unsigned int;
103 
107  using InternalParametersType = std::vector< ParameterType >;
109 
113 
115 
118  using MembershipFunctionVectorType = std::vector< MembershipFunctionPointer >;
121 
124  const MembershipFunctionVectorObjectType * GetOutput() const;
125 
127  itkSetMacro(Parameters, ParametersType);
128  itkGetConstMacro(Parameters, ParametersType);
130 
132  itkSetMacro(MaximumIteration, int);
133  itkGetConstMacro(MaximumIteration, int);
135 
138  itkSetMacro(CentroidPositionChangesThreshold, double);
139  itkGetConstMacro(CentroidPositionChangesThreshold, double);
140 
142  void SetKdTree(TKdTree *tree);
143 
144  const TKdTree * GetKdTree() const;
145 
147  itkGetConstMacro(MeasurementVectorSize, MeasurementVectorSizeType);
148 
149  itkGetConstMacro(CurrentIteration, int);
150  itkGetConstMacro(CentroidPositionChanges, double);
151 
156  void StartOptimization();
157 
/** Hash map keyed by the tree's InstanceIdentifier with an unsigned int
 * value per key. NOTE(review): presumably the value is the index of the
 * cluster the instance was assigned to -- confirm against the .hxx. */
158  using ClusterLabelsType = itksys::hash_map< InstanceIdentifier, unsigned int >;
159 
160  itkSetMacro(UseClusterLabels, bool);
161  itkGetConstMacro(UseClusterLabels, bool);
162 
163 protected:
165  ~KdTreeBasedKmeansEstimator() override = default;
166 
167  void PrintSelf(std::ostream & os, Indent indent) const override;
168 
169  void FillClusterLabels(KdTreeNodeType *node, int closestIndex);
170 
176  {
177 public:
179 
180  struct Candidate {
183  int Size;
184  }; // end of struct
185 
186  virtual ~CandidateVector() = default;
187 
/** Number of candidates currently stored, i.e. m_Candidates.size()
 * converted to int with static_cast. The narrowing is unchecked; the loops
 * in this nested class cast the result back to unsigned int before
 * comparing it with their index. */
189  int Size() const
190  {
191  return static_cast< int >( m_Candidates.size() );
192  }
193 
197  {
198  this->m_MeasurementVectorSize = NumericTraits<ParameterType>::GetLength(centroids[0]);
199  m_Candidates.resize( centroids.size() );
200  for ( unsigned int i = 0; i < centroids.size(); i++ )
201  {
202  Candidate candidate;
203  candidate.Centroid = centroids[i];
205  m_MeasurementVectorSize);
206  candidate.WeightedCentroid.Fill(0.0);
207  candidate.Size = 0;
208  m_Candidates[i] = candidate;
209  }
210  }
212 
215  {
216  unsigned int i;
217 
218  centroids.resize( this->Size() );
219  for ( i = 0; i < (unsigned int)this->Size(); i++ )
220  {
221  centroids[i] = m_Candidates[i].Centroid;
222  }
223  }
224 
228  {
229  unsigned int i, j;
230 
231  for ( i = 0; i < (unsigned int)this->Size(); i++ )
232  {
233  if ( m_Candidates[i].Size > 0 )
234  {
235  for ( j = 0; j < m_MeasurementVectorSize; j++ )
236  {
237  m_Candidates[i].Centroid[j] =
238  m_Candidates[i].WeightedCentroid[j]
239  / double(m_Candidates[i].Size);
240  }
241  }
242  }
243  }
244 
/** Mutable access to the candidate stored at position `index`.
 * Forwards straight to m_Candidates[index]: no range check is performed
 * here and the signed `index` is handed to the vector subscript as-is, so
 * a negative or too-large value is the caller's responsibility. */
246  Candidate & operator[](int index)
247  {
248  return m_Candidates[index];
249  }
250 
251 private:
253  std::vector< Candidate > m_Candidates;
254 
256  MeasurementVectorSizeType m_MeasurementVectorSize{0};
257  }; // end of class
258 
/** Declaration only; the definition is not part of this header.
 * Returns a double computed from two InternalParametersType arguments
 * (std::vector< ParameterType >), both taken by non-const reference.
 * NOTE(review): by its name this presumably accumulates the squared
 * displacement between `previous` and `current` -- confirm against the
 * definition. */
264  double GetSumOfSquaredPositionChanges(InternalParametersType & previous,
265  InternalParametersType & current);
266 
/** Declaration only; the definition is not part of this header.
 * Takes a point (`measurements`) and a list of int indexes
 * (`validIndexes`), both by non-const reference, and returns an int.
 * NOTE(review): presumably the returned int is one of the entries of
 * `validIndexes` identifying the nearest candidate -- confirm against the
 * definition. */
269  int GetClosestCandidate(ParameterType & measurements,
270  std::vector< int > & validIndexes);
271 
/** Declaration only; the definition is not part of this header.
 * Boolean predicate over two ParameterType points and a lower/upper bound
 * pair of MeasurementVectorType; all four arguments are non-const
 * references. NOTE(review): presumably reports whether pointA is farther
 * than pointB from the box described by the bounds -- confirm against the
 * definition. */
273  bool IsFarther(ParameterType & pointA,
274  ParameterType & pointB,
275  MeasurementVectorType & lowerBound,
276  MeasurementVectorType & upperBound);
277 
/** Declaration only; the definition is not part of this header.
 * Receives a k-d tree node, a list of int indexes and a lower/upper bound
 * pair. `validIndexes` is passed by value, so every call works on its own
 * copy of the list, whereas both bounds are non-const references that the
 * callee is able to modify.
 * NOTE(review): presumably the by-value copy is deliberate so a call can
 * prune its list without affecting the caller -- confirm against the
 * definition before changing it to a reference. */
280  void Filter(KdTreeNodeType *node,
281  std::vector< int > validIndexes,
282  MeasurementVectorType & lowerBound,
283  MeasurementVectorType & upperBound);
284 
/** Three overloads sharing the (source, target) parameter order; only the
 * argument types differ. All arguments, including `source`, are non-const
 * references. Declarations only; the definitions are not part of this
 * header. */
/** InternalParametersType -> InternalParametersType. */
286  void CopyParameters(InternalParametersType & source, InternalParametersType & target);
287 
/** ParametersType -> InternalParametersType. */
289  void CopyParameters(ParametersType & source, InternalParametersType & target);
290 
/** InternalParametersType -> ParametersType. */
292  void CopyParameters(InternalParametersType & source, ParametersType & target);
293 
/** Declaration only; the definition is not part of this header.
 * `point` is a non-const reference while `measurements` is taken by value
 * (one MeasurementVectorType copy per call). NOTE(review): presumably
 * `point` is the output filled from `measurements` -- confirm against the
 * definition. */
295  void GetPoint(ParameterType & point, MeasurementVectorType measurements);
296 
/** Declaration only; the definition is not part of this header.
 * Takes the point by non-const reference and returns nothing.
 * NOTE(review): the destination of the printed output is not visible
 * here -- check the definition. */
297  void PrintPoint(ParameterType & point);
298 
299 private:
/** Iteration counter, initialised to 0. The class lists an
 * itkGetConstMacro entry for CurrentIteration and no itkSetMacro entry. */
301  int m_CurrentIteration{0};
302 
/** Iteration limit, initialised to 100. The class lists both itkSetMacro
 * and itkGetConstMacro entries for MaximumIteration. */
304  int m_MaximumIteration{100};
305 
/** Initialised to 0.0. The class lists an itkGetConstMacro entry for
 * CentroidPositionChanges and no itkSetMacro entry. */
307  double m_CentroidPositionChanges{0.0};
308 
/** Initialised to 0.0. The class lists both itkSetMacro and
 * itkGetConstMacro entries for CentroidPositionChangesThreshold.
 * NOTE(review): presumably the value m_CentroidPositionChanges is compared
 * against to decide when to stop -- confirm against the .hxx. */
311  double m_CentroidPositionChangesThreshold{0.0};
312 
/** Handle to the k-d tree, of type TKdTree::Pointer; it has no default
 * member initializer in this declaration. The class declares
 * SetKdTree(TKdTree *) and GetKdTree() const alongside it. */
314  typename TKdTree::Pointer m_KdTree;
315 
318 
321 
323 
325 
/** Initialised to false. The class lists both itkSetMacro and
 * itkGetConstMacro entries for UseClusterLabels. */
326  bool m_UseClusterLabels{false};
/** Initialised to false. No itkSetMacro/itkGetConstMacro entry for this
 * member is visible in this class declaration. */
327  bool m_GenerateClusterLabels{false};
/** Initialised to 0. The class lists an itkGetConstMacro entry for
 * MeasurementVectorSize and no itkSetMacro entry. The nested candidate
 * container declares a separate member with the same name. */
329  MeasurementVectorSizeType m_MeasurementVectorSize{0};
331 }; // end of class
332 } // end of namespace Statistics
333 } // end of namespace itk
334 
335 #ifndef ITK_MANUAL_INSTANTIATION
336 #include "itkKdTreeBasedKmeansEstimator.hxx"
337 #endif
338 
339 #endif
EuclideanDistanceMetric< ParameterType >::Pointer m_DistanceMetric
Light weight base class for most itk classes.
typename DistanceToCentroidMembershipFunctionType::Pointer DistanceToCentroidMembershipFunctionPointer
Define numeric traits for std::vector.
typename TKdTree::InstanceIdentifier InstanceIdentifier
itksys::hash_map< InstanceIdentifier, unsigned int > ClusterLabelsType
Fast k-means algorithm implementation using a k-d tree structure.
typename MembershipFunctionVectorObjectType::Pointer MembershipFunctionVectorObjectPointer
typename KdTreeNodeType::CentroidType CentroidType
MembershipFunctionVectorObjectPointer m_MembershipFunctionsObject
Decorates any "simple" data type (data types without smart pointers) with a DataObject API...
std::vector< MembershipFunctionPointer > MembershipFunctionVectorType
MembershipFunctionBase defines common interfaces for membership functions.
Represent a n-dimensional size (bounds) of a n-dimensional image.
Definition: itkSize.h:68
typename TKdTree::MeasurementVectorType MeasurementVectorType
DistanceToCentroidMembershipFunction models class membership using a distance metric.
Control indentation during Print() invocation.
Definition: itkIndent.h:49
Base class for most ITK classes.
Definition: itkObject.h:60
typename MembershipFunctionType::ConstPointer MembershipFunctionPointer