ITK
4.1.0
Insight Segmentation and Registration Toolkit
|
00001 /*========================================================================= 00002 * 00003 * Copyright Insight Software Consortium 00004 * 00005 * Licensed under the Apache License, Version 2.0 (the "License"); 00006 * you may not use this file except in compliance with the License. 00007 * You may obtain a copy of the License at 00008 * 00009 * http://www.apache.org/licenses/LICENSE-2.0.txt 00010 * 00011 * Unless required by applicable law or agreed to in writing, software 00012 * distributed under the License is distributed on an "AS IS" BASIS, 00013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00014 * See the License for the specific language governing permissions and 00015 * limitations under the License. 00016 * 00017 *=========================================================================*/ 00018 #ifndef __itkTrainingFunctionBase_h 00019 #define __itkTrainingFunctionBase_h 00020 00021 #include <iostream> 00022 #include "itkLightProcessObject.h" 00023 #include "itkNeuralNetworkObject.h" 00024 #include "itkSquaredDifferenceErrorFunction.h" 00025 #include "itkMeanSquaredErrorFunction.h" 00026 namespace itk 00027 { 00028 namespace Statistics 00029 { 00036 template<class TSample, class TTargetVector, class ScalarType> 00037 class TrainingFunctionBase : public LightProcessObject 00038 { 00039 public: 00040 typedef TrainingFunctionBase Self; 00041 typedef LightProcessObject Superclass; 00042 typedef SmartPointer<Self> Pointer; 00043 typedef SmartPointer<const Self> ConstPointer; 00044 00046 itkTypeMacro(TrainingFunctionBase, LightProcessObject); 00047 00049 itkNewMacro(Self); 00050 00051 typedef ScalarType ValueType; 00052 typedef typename TSample::MeasurementVectorType VectorType; 00053 typedef typename TTargetVector::MeasurementVectorType OutputVectorType; 00054 typedef Array<ValueType> InternalVectorType; 00055 00056 typedef std::vector<VectorType> InputSampleVectorType; 00057 typedef std::vector<OutputVectorType> OutputSampleVectorType; 00058 typedef NeuralNetworkObject<VectorType, OutputVectorType> NetworkType; 00059 typedef ErrorFunctionBase<InternalVectorType, ScalarType> PerformanceFunctionType; 00060 typedef SquaredDifferenceErrorFunction<InternalVectorType, ScalarType> 00061 DefaultPerformanceType; 00062 00063 void SetTrainingSamples(TSample* samples); 00064 void SetTargetValues(TTargetVector* targets); 00065 void SetLearningRate(ValueType); 00066 00067 ValueType GetLearningRate(); 00068 00069 itkSetMacro(Iterations, SizeValueType); 00070 itkGetConstReferenceMacro(Iterations, SizeValueType); 00071 00072 void SetPerformanceFunction(PerformanceFunctionType* f); 00073 00074 virtual void Train(NetworkType* itkNotUsed(net), TSample* itkNotUsed(samples), TTargetVector* itkNotUsed(targets)) 00075 { 00076 // not implemented 00077 }; 00078 00079 inline VectorType 00080 defaultconverter(typename TSample::MeasurementVectorType v) 00081 { 00082 VectorType temp; 00083 for (unsigned int i = 0; i < v.Size(); i++) 00084 { 00085 temp[i] = static_cast<ScalarType>(v[i]); 00086 } 00087 return temp; 00088 } 00089 00090 inline OutputVectorType 00091 targetconverter(typename TTargetVector::MeasurementVectorType v) 00092 { 00093 OutputVectorType temp; 00094 00095 for (unsigned int i = 0; i < v.Size(); i++) 00096 { 00097 temp[i] = static_cast<ScalarType>(v[i]); 00098 } 00099 return temp; 00100 } 00101 00102 protected: 00103 00104 TrainingFunctionBase(); 00105 ~TrainingFunctionBase(){}; 00106 00108 virtual void PrintSelf( std::ostream& os, Indent indent ) const; 00109 00110 TSample* m_TrainingSamples;// original samples 00111 TTargetVector* m_SampleTargets; // original samples 00112 InputSampleVectorType m_InputSamples; // itk::vectors 00113 OutputSampleVectorType m_Targets; // itk::vectors 00114 SizeValueType m_Iterations; 00115 ValueType m_LearningRate; 00116 00117 typename PerformanceFunctionType::Pointer m_PerformanceFunction; 00118 }; 00119 00120 } // end namespace Statistics 00121 } // end namespace itk 00122 #ifndef ITK_MANUAL_INSTANTIATION 00123 #include "itkTrainingFunctionBase.hxx" 00124 #endif 00125 00126 #endif 00127