ITK  4.0.0
Insight Segmentation and Registration Toolkit
itkTrainingFunctionBase.h
Go to the documentation of this file.
00001 /*=========================================================================
00002  *
00003  *  Copyright Insight Software Consortium
00004  *
00005  *  Licensed under the Apache License, Version 2.0 (the "License");
00006  *  you may not use this file except in compliance with the License.
00007  *  You may obtain a copy of the License at
00008  *
00009  *         http://www.apache.org/licenses/LICENSE-2.0.txt
00010  *
00011  *  Unless required by applicable law or agreed to in writing, software
00012  *  distributed under the License is distributed on an "AS IS" BASIS,
00013  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00014  *  See the License for the specific language governing permissions and
00015  *  limitations under the License.
00016  *
00017  *=========================================================================*/
00018 #ifndef __itkTrainingFunctionBase_h
00019 #define __itkTrainingFunctionBase_h
00020 
00021 #include <iostream>
00022 #include "itkLightProcessObject.h"
00023 #include "itkNeuralNetworkObject.h"
00024 #include "itkSquaredDifferenceErrorFunction.h"
00025 #include "itkMeanSquaredErrorFunction.h"
00026 namespace itk
00027 {
00028 namespace Statistics
00029 {
00036 template<class TSample, class TTargetVector, class ScalarType>
00037 class TrainingFunctionBase : public LightProcessObject
00038 {
00039 public:
00040   typedef TrainingFunctionBase     Self;
00041   typedef LightProcessObject       Superclass;
00042   typedef SmartPointer<Self>       Pointer;
00043   typedef SmartPointer<const Self> ConstPointer;
00044 
00046   itkTypeMacro(TrainingFunctionBase, LightProcessObject);
00047 
00049   itkNewMacro(Self);
00050 
00051   typedef ScalarType                                    ValueType;
00052   typedef typename TSample::MeasurementVectorType       VectorType;
00053   typedef typename TTargetVector::MeasurementVectorType OutputVectorType;
00054   typedef Array<ValueType>                              InternalVectorType;
00055 
00056   typedef std::vector<VectorType>                           InputSampleVectorType;
00057   typedef std::vector<OutputVectorType>                     OutputSampleVectorType;
00058   typedef NeuralNetworkObject<VectorType, OutputVectorType> NetworkType;
00059   typedef ErrorFunctionBase<InternalVectorType, ScalarType> PerformanceFunctionType;
00060   typedef SquaredDifferenceErrorFunction<InternalVectorType, ScalarType>
00061                                                             DefaultPerformanceType;
00062 
00063   void SetTrainingSamples(TSample* samples);
00064   void SetTargetValues(TTargetVector* targets);
00065   void SetLearningRate(ValueType);
00066 
00067   ValueType GetLearningRate();
00068 
00069   itkSetMacro(Iterations, SizeValueType);
00070   itkGetConstReferenceMacro(Iterations, SizeValueType);
00071 
00072   void SetPerformanceFunction(PerformanceFunctionType* f);
00073 
00074   virtual void Train(NetworkType* itkNotUsed(net), TSample* itkNotUsed(samples), TTargetVector* itkNotUsed(targets))
00075     {
00076     // not implemented
00077     };
00078 
00079   inline VectorType
00080   defaultconverter(typename TSample::MeasurementVectorType v)
00081     {
00082     VectorType temp;
00083     for (unsigned int i = 0; i < v.Size(); i++)
00084       {
00085       temp[i] = static_cast<ScalarType>(v[i]);
00086       }
00087     return temp;
00088     }
00089 
00090   inline OutputVectorType
00091   targetconverter(typename TTargetVector::MeasurementVectorType v)
00092     {
00093     OutputVectorType temp;
00094 
00095     for (unsigned int i = 0; i < v.Size(); i++)
00096       {
00097       temp[i] = static_cast<ScalarType>(v[i]);
00098       }
00099     return temp;
00100     }
00101 
00102 protected:
00103 
00104   TrainingFunctionBase();
00105   ~TrainingFunctionBase(){};
00106 
00108   virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00109 
00110   TSample*                m_TrainingSamples;// original samples
00111   TTargetVector*          m_SampleTargets;  // original samples
00112   InputSampleVectorType   m_InputSamples;   // itk::vectors
00113   OutputSampleVectorType  m_Targets;        // itk::vectors
00114   SizeValueType           m_Iterations;
00115   ValueType               m_LearningRate;
00116 
00117   typename PerformanceFunctionType::Pointer m_PerformanceFunction;
00118 };
00119 
00120 } // end namespace Statistics
00121 } // end namespace itk
00122 #ifndef ITK_MANUAL_INSTANTIATION
00123 #include "itkTrainingFunctionBase.hxx"
00124 #endif
00125 
00126 #endif
00127