00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018 #ifndef __itkTrainingFunctionBase_h
00019 #define __itkTrainingFunctionBase_h
00020
00021 #include <iostream>
00022 #include "itkLightProcessObject.h"
00023 #include "itkNeuralNetworkObject.h"
00024 #include "itkSquaredDifferenceErrorFunction.h"
00025 #include "itkMeanSquaredErrorFunction.h"
00026 namespace itk
00027 {
00028 namespace Statistics
00029 {
00030
00031 template<class TSample, class TOutput, class ScalarType>
00032 class TrainingFunctionBase : public LightProcessObject
00033 {
00034 public:
00035 typedef TrainingFunctionBase Self;
00036 typedef LightProcessObject Superclass;
00037 typedef SmartPointer<Self> Pointer;
00038 typedef SmartPointer<const Self> ConstPointer;
00039
00041 itkTypeMacro(TrainingFunctionBase, LightProcessObject);
00042
00044 itkNewMacro(Self);
00045
00046 typedef ScalarType ValueType;
00047 typedef typename TSample::MeasurementVectorType VectorType;
00048 typedef typename TOutput::MeasurementVectorType OutputVectorType;
00049 typedef Array<ValueType> InternalVectorType;
00050
00051 typedef std::vector<VectorType> InputSampleVectorType;
00052 typedef std::vector<OutputVectorType> OutputSampleVectorType;
00053 typedef NeuralNetworkObject<VectorType, OutputVectorType> NetworkType;
00054 typedef ErrorFunctionBase<InternalVectorType, ScalarType> PerformanceFunctionType;
00055 typedef SquaredDifferenceErrorFunction<InternalVectorType, ScalarType> DefaultPerformanceType;
00056
00057
00058 void SetTrainingSamples(TSample* samples);
00059 void SetTargetValues(TOutput* targets);
00060 void SetLearningRate(ValueType);
00061
00062 ValueType GetLearningRate();
00063
00064 itkSetMacro(Iterations, long);
00065 itkGetConstReferenceMacro(Iterations, long);
00066
00067 void SetPerformanceFunction(PerformanceFunctionType* f);
00068
00069 virtual void
00070 Train(NetworkType* itkNotUsed(net), TSample* itkNotUsed(samples), TOutput* itkNotUsed(targets))
00071 {
00072
00073 };
00074
00075 inline VectorType
00076 defaultconverter(typename TSample::MeasurementVectorType v)
00077 {
00078 VectorType temp;
00079 for (unsigned int i = 0; i < v.Size(); i++)
00080 {
00081 temp[i] = static_cast<ScalarType>(v[i]) ;
00082 }
00083 return temp;
00084 }
00085
00086 inline OutputVectorType
00087 targetconverter(typename TOutput::MeasurementVectorType v)
00088 {
00089 OutputVectorType temp;
00090
00091 for (unsigned int i = 0; i < v.Size(); i++)
00092 {
00093 temp[i] = static_cast<ScalarType>(v[i]) ;
00094 }
00095 return temp;
00096 }
00097
00098 protected:
00099
00100 TrainingFunctionBase();
00101 ~TrainingFunctionBase(){};
00102
00104 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00105
00106 TSample* m_TrainingSamples;
00107 TOutput* m_SampleTargets;
00108 InputSampleVectorType m_InputSamples;
00109 OutputSampleVectorType m_Targets;
00110 long m_Iterations;
00111 ValueType m_LearningRate;
00112 typename PerformanceFunctionType::Pointer m_PerformanceFunction;
00113 };
00114
00115 }
00116 }
00117 #ifndef ITK_MANUAL_INSTANTIATION
00118 #include "itkTrainingFunctionBase.txx"
00119 #endif
00120
00121 #endif
00122