00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018 #ifndef __itkTrainingFunctionBase_h
00019 #define __itkTrainingFunctionBase_h
00020
00021 #include <iostream>
00022 #include "itkLightProcessObject.h"
00023 #include "itkNeuralNetworkObject.h"
00024 #include "itkSquaredDifferenceErrorFunction.h"
00025 #include "itkMeanSquaredErrorFunction.h"
00026 namespace itk
00027 {
00028 namespace Statistics
00029 {
00030
00031 template<class TSample, class TTargetVector, class ScalarType>
00032 class TrainingFunctionBase : public LightProcessObject
00033 {
00034 public:
00035 typedef TrainingFunctionBase Self;
00036 typedef LightProcessObject Superclass;
00037 typedef SmartPointer<Self> Pointer;
00038 typedef SmartPointer<const Self> ConstPointer;
00039
00041 itkTypeMacro(TrainingFunctionBase, LightProcessObject);
00042
00044 itkNewMacro(Self);
00045
00046 typedef ScalarType ValueType;
00047 typedef typename TSample::MeasurementVectorType VectorType;
00048 typedef typename TTargetVector::MeasurementVectorType OutputVectorType;
00049 typedef Array<ValueType> InternalVectorType;
00050
00051 typedef std::vector<VectorType> InputSampleVectorType;
00052 typedef std::vector<OutputVectorType> OutputSampleVectorType;
00053 typedef NeuralNetworkObject<VectorType, OutputVectorType> NetworkType;
00054 typedef ErrorFunctionBase<InternalVectorType, ScalarType> PerformanceFunctionType;
00055 typedef SquaredDifferenceErrorFunction<InternalVectorType, ScalarType>
00056 DefaultPerformanceType;
00057
00058 void SetTrainingSamples(TSample* samples);
00059 void SetTargetValues(TTargetVector* targets);
00060 void SetLearningRate(ValueType);
00061
00062 ValueType GetLearningRate();
00063
00064 itkSetMacro(Iterations, long);
00065 itkGetConstReferenceMacro(Iterations, long);
00066
00067 void SetPerformanceFunction(PerformanceFunctionType* f);
00068
00069 virtual void Train(NetworkType* itkNotUsed(net), TSample* itkNotUsed(samples), TTargetVector* itkNotUsed(targets))
00070 {
00071
00072 };
00073
00074 inline VectorType
00075 defaultconverter(typename TSample::MeasurementVectorType v)
00076 {
00077 VectorType temp;
00078 for (unsigned int i = 0; i < v.Size(); i++)
00079 {
00080 temp[i] = static_cast<ScalarType>(v[i]);
00081 }
00082 return temp;
00083 }
00084
00085 inline OutputVectorType
00086 targetconverter(typename TTargetVector::MeasurementVectorType v)
00087 {
00088 OutputVectorType temp;
00089
00090 for (unsigned int i = 0; i < v.Size(); i++)
00091 {
00092 temp[i] = static_cast<ScalarType>(v[i]);
00093 }
00094 return temp;
00095 }
00096
00097 protected:
00098
00099 TrainingFunctionBase();
00100 ~TrainingFunctionBase(){};
00101
00103 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00104
00105 TSample* m_TrainingSamples;
00106 TTargetVector* m_SampleTargets;
00107 InputSampleVectorType m_InputSamples;
00108 OutputSampleVectorType m_Targets;
00109 long m_Iterations;
00110 ValueType m_LearningRate;
00111
00112 typename PerformanceFunctionType::Pointer m_PerformanceFunction;
00113 };
00114
00115 }
00116 }
00117 #ifndef ITK_MANUAL_INSTANTIATION
00118 #include "itkTrainingFunctionBase.txx"
00119 #endif
00120
00121 #endif
00122