Main Page   Groups   Namespace List   Class Hierarchy   Alphabetical List   Compound List   File List   Namespace Members   Compound Members   File Members   Concepts

itkTrainingFunctionBase.h

Go to the documentation of this file.
00001 /*=========================================================================
00002 
00003   Program:   Insight Segmentation & Registration Toolkit
00004   Module:    $RCSfile: itkTrainingFunctionBase.h,v $
00005   Language:  C++
00006   Date:      $Date: 2007/08/17 13:10:57 $
00007   Version:   $Revision: 1.7 $
00008 
00009   Copyright (c) Insight Software Consortium. All rights reserved.
00010   See ITKCopyright.txt or http://www.itk.org/HTML/Copyright.htm for details.
00011 
00012      This software is distributed WITHOUT ANY WARRANTY; without even 
00013      the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
00014      PURPOSE.  See the above copyright notices for more information.
00015 
00016 =========================================================================*/
00017 
00018 #ifndef __itkTrainingFunctionBase_h
00019 #define __itkTrainingFunctionBase_h
00020 
00021 #include <iostream>
00022 #include "itkLightProcessObject.h"
00023 #include "itkNeuralNetworkObject.h"
00024 #include "itkSquaredDifferenceErrorFunction.h"
00025 #include "itkMeanSquaredErrorFunction.h"
00026 namespace itk
00027 {
00028 namespace Statistics
00029 {
00030 
00031 template<class TSample, class TTargetVector, class ScalarType>
00032 class TrainingFunctionBase : public LightProcessObject
00033 {
00034 public:
00035   typedef TrainingFunctionBase Self;
00036   typedef LightProcessObject Superclass;
00037   typedef SmartPointer<Self> Pointer;
00038   typedef SmartPointer<const Self> ConstPointer;
00039 
00041   itkTypeMacro(TrainingFunctionBase, LightProcessObject);
00042 
00044   itkNewMacro(Self);
00045 
00046   typedef ScalarType ValueType;
00047   typedef typename TSample::MeasurementVectorType VectorType;
00048   typedef typename TTargetVector::MeasurementVectorType OutputVectorType;
00049   typedef Array<ValueType> InternalVectorType;
00050 
00051   typedef std::vector<VectorType> InputSampleVectorType;
00052   typedef std::vector<OutputVectorType> OutputSampleVectorType;
00053   typedef NeuralNetworkObject<VectorType, OutputVectorType> NetworkType;
00054   typedef ErrorFunctionBase<InternalVectorType, ScalarType> PerformanceFunctionType;
00055   typedef SquaredDifferenceErrorFunction<InternalVectorType, ScalarType> DefaultPerformanceType;
00056   //typedef MeanSquaredErrorFunction<InternalVectorType, ScalarType> DefaultPerformanceType;
00057 
00058   void SetTrainingSamples(TSample* samples);
00059   void SetTargetValues(TTargetVector* targets);
00060   void SetLearningRate(ValueType);
00061 
00062   ValueType GetLearningRate();
00063 
00064   itkSetMacro(Iterations, long);
00065   itkGetConstReferenceMacro(Iterations, long);
00066 
00067   void SetPerformanceFunction(PerformanceFunctionType* f);
00068 
00069   virtual void Train(NetworkType* itkNotUsed(net), TSample* itkNotUsed(samples), TTargetVector* itkNotUsed(targets))
00070     {
00071     // not implemented
00072     };
00073 
00074   inline VectorType
00075   defaultconverter(typename TSample::MeasurementVectorType v)
00076     {
00077     VectorType temp;
00078     for (unsigned int i = 0; i < v.Size(); i++)
00079       {
00080       temp[i] = static_cast<ScalarType>(v[i]) ;
00081       }
00082     return temp;
00083     }
00084 
00085   inline OutputVectorType
00086   targetconverter(typename TTargetVector::MeasurementVectorType v)
00087     {
00088     OutputVectorType temp;
00089     
00090     for (unsigned int i = 0; i < v.Size(); i++)
00091       {
00092       temp[i] = static_cast<ScalarType>(v[i]) ;
00093       }
00094     return temp;
00095     }
00096 
00097 protected:
00098 
00099   TrainingFunctionBase();
00100   ~TrainingFunctionBase(){};
00101    
00103   virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00104 
00105   TSample*                m_TrainingSamples;// original samples
00106   TTargetVector*                m_SampleTargets;  // original samples
00107   InputSampleVectorType   m_InputSamples;   // itk::vectors
00108   OutputSampleVectorType  m_Targets;        // itk::vectors
00109   long                    m_Iterations;    
00110   ValueType               m_LearningRate;
00111   typename PerformanceFunctionType::Pointer m_PerformanceFunction;
00112 };
00113 
00114 } // end namespace Statistics
00115 } // end namespace itk
00116 #ifndef ITK_MANUAL_INSTANTIATION
00117 #include "itkTrainingFunctionBase.txx"
00118 #endif
00119 
00120 #endif
00121 

Generated at Sun Sep 23 14:26:51 2007 for ITK by doxygen 1.5.1 written by Dimitri van Heesch, © 1997-2000