ITK  4.13.0
Insight Segmentation and Registration Toolkit
itkTrainingFunctionBase.h
Go to the documentation of this file.
1 /*=========================================================================
2  *
3  * Copyright Insight Software Consortium
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  *=========================================================================*/
18 #ifndef itkTrainingFunctionBase_h
19 #define itkTrainingFunctionBase_h
20 
21 #include <iostream>
22 #include "itkLightProcessObject.h"
23 #include "itkNeuralNetworkObject.h"
26 namespace itk
27 {
28 namespace Statistics
29 {
36 template<typename TSample, typename TTargetVector, typename ScalarType>
37 class ITK_TEMPLATE_EXPORT TrainingFunctionBase : public LightProcessObject
38 {
39 public:
44 
47 
49  itkNewMacro(Self);
50 
51  typedef ScalarType ValueType;
52  typedef typename TSample::MeasurementVectorType VectorType;
53  typedef typename TTargetVector::MeasurementVectorType OutputVectorType;
55 
56  typedef std::vector<VectorType> InputSampleVectorType;
57  typedef std::vector<OutputVectorType> OutputSampleVectorType;
62 
63  void SetTrainingSamples(TSample* samples);
64  void SetTargetValues(TTargetVector* targets);
65  void SetLearningRate(ValueType);
66 
67  ValueType GetLearningRate();
68 
69  itkSetMacro(Iterations, SizeValueType);
70  itkGetConstReferenceMacro(Iterations, SizeValueType);
71 
72  void SetPerformanceFunction(PerformanceFunctionType* f);
73 
74  virtual void Train(NetworkType* itkNotUsed(net), TSample* itkNotUsed(samples), TTargetVector* itkNotUsed(targets))
75  {
76  // not implemented
77  };
78 
79  inline VectorType
80  defaultconverter(typename TSample::MeasurementVectorType v)
81  {
82  VectorType temp;
83  for (unsigned int i = 0; i < v.Size(); i++)
84  {
85  temp[i] = static_cast<ScalarType>(v[i]);
86  }
87  return temp;
88  }
89 
90  inline OutputVectorType
91  targetconverter(typename TTargetVector::MeasurementVectorType v)
92  {
93  OutputVectorType temp;
94 
95  for (unsigned int i = 0; i < v.Size(); i++)
96  {
97  temp[i] = static_cast<ScalarType>(v[i]);
98  }
99  return temp;
100  }
101 
102 protected:
103 
105  ~TrainingFunctionBase() ITK_OVERRIDE {}
106 
108  virtual void PrintSelf( std::ostream& os, Indent indent ) const ITK_OVERRIDE;
109 
110  TSample* m_TrainingSamples;// original samples
111  TTargetVector* m_SampleTargets; // original samples
116 
118 };
119 
120 } // end namespace Statistics
121 } // end namespace itk
122 #ifndef ITK_MANUAL_INSTANTIATION
123 #include "itkTrainingFunctionBase.hxx"
124 #endif
125 
126 #endif
Array class with size defined at construction time.
Definition: itkArray.h:50
Light weight base class for most itk classes.
ErrorFunctionBase< InternalVectorType, ScalarType > PerformanceFunctionType
VectorType defaultconverter(typename TSample::MeasurementVectorType v)
OutputVectorType targetconverter(typename TTargetVector::MeasurementVectorType v)
unsigned long SizeValueType
Definition: itkIntTypes.h:143
This is the itkErrorFunctionBase class.
std::vector< OutputVectorType > OutputSampleVectorType
std::vector< VectorType > InputSampleVectorType
PerformanceFunctionType::Pointer m_PerformanceFunction
SquaredDifferenceErrorFunction< InternalVectorType, ScalarType > DefaultPerformanceType
TSample::MeasurementVectorType VectorType
This is the itkNeuralNetworkObject class.
This is the itkTrainingFunctionBase class.
NeuralNetworkObject< VectorType, OutputVectorType > NetworkType
LightProcessObject is the base class for all process objects (source, filters, mappers) in the Insight data processing pipeline.
Control indentation during Print() invocation.
Definition: itkIndent.h:49
virtual void Train(NetworkType *, TSample *, TTargetVector *)
TTargetVector::MeasurementVectorType OutputVectorType
This is the itkSquaredDifferenceErrorFunction class.