Main Page   Groups   Namespace List   Class Hierarchy   Alphabetical List   Compound List   File List   Namespace Members   Compound Members   File Members   Concepts

itkWeightSetBase.h

Go to the documentation of this file.
00001 /*=========================================================================
00002 
00003 Program:   Insight Segmentation & Registration Toolkit
00004 Module:    $RCSfile: itkWeightSetBase.h,v $
00005 Language:  C++
00006 Date:      $Date: 2007/08/17 13:10:57 $
00007 Version:   $Revision: 1.9 $
00008 
00009 Copyright (c) Insight Software Consortium. All rights reserved.
00010 See ITKCopyright.txt or http://www.itk.org/HTML/Copyright.htm for details.
00011 
00012 This software is distributed WITHOUT ANY WARRANTY; without even
00013 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
00014 PURPOSE.  See the above copyright notices for more information.
00015 
00016 =========================================================================*/
00017 
00018 #ifndef __itkWeightSetBase_h
00019 #define __itkWeightSetBase_h
00020 
00021 #include "itkLightProcessObject.h"
00022 #include <vnl/vnl_matrix.h>
00023 #include <vnl/vnl_diag_matrix.h>
00024 #include "itkMacro.h"
00025 #include "itkVector.h"
00026 #include "itkMersenneTwisterRandomVariateGenerator.h"
00027 #include <math.h>
00028 #include <stdlib.h>
00029 
00030 namespace itk
00031 {
00032   namespace Statistics
00033     {
00034     /** Weight set joining an input layer to an output layer, plus the delta/bias history used when updating it. */
00035     template<class TMeasurementVector, class TTargetVector>  // NOTE(review): TTargetVector is not referenced in this header
00036       class WeightSetBase : public LightProcessObject
00037         {
00038       public:
00039         // No itkNewMacro and a protected constructor: instances are created through a concrete subclass.
00040         typedef WeightSetBase            Self;
00041         typedef LightProcessObject       Superclass;
00042         typedef SmartPointer<Self>       Pointer;
00043         typedef SmartPointer<const Self> ConstPointer;
00044         itkTypeMacro(WeightSetBase, LightProcessObject);  // run-time type information (GetNameOfClass)
00045         // Type of the random number generator held in m_RandomGenerator.
00046         typedef MersenneTwisterRandomVariateGenerator  RandomVariateGeneratorType;
00047         // Scalar type of every weight, delta and bias; TMeasurementVector must expose a nested ValueType.
00048         typedef typename TMeasurementVector::ValueType ValueType;
00049         typedef ValueType*                  ValuePointer;
00050         typedef const ValueType*            ValueConstPointer;
00051         // Defined out of line; NOTE(review): presumably sizes the internal matrices from the node counts - confirm in the .txx.
00052         void Initialize();
00053         // Returns a random weight value; NOTE(review): presumably drawn from [low, high] via m_RandomGenerator - confirm.
00054         ValueType RandomWeightValue(ValueType low, ValueType high);
00055         // Forward/backward passes. The raw-pointer arguments carry no length; callers are assumed to size them to the node counts (TODO confirm).
00056         virtual void ForwardPropagate(ValuePointer inputlayeroutputvalues);
00057         virtual void BackwardPropagate(ValuePointer inputerror);
00058         // NOTE(review): takes the matrix by value (full copy); presumably stored in m_ConnectivityMatrix.
00059         void SetConnectivityMatrix(vnl_matrix < int>);
00060         // Number of nodes on the input side of this weight set.
00061         void SetNumberOfInputNodes(unsigned int n);
00062         unsigned int GetNumberOfInputNodes() const;
00063         // Number of nodes on the output side of this weight set.
00064         void SetNumberOfOutputNodes(unsigned int n);
00065         unsigned int GetNumberOfOutputNodes() const;
00066         // Sets m_Range (no getter is declared); NOTE(review): presumably the spread of the random initial weights.
00067         void SetRange(ValueType Range);
00068         // Raw pointers with no length; NOTE(review): presumably point into the internal value buffers - confirm ownership/lifetime in the .txx.
00069         virtual ValuePointer GetOutputValues();
00070         virtual ValuePointer GetInputValues();
00071         // Accessors for the delta and bias-delta buffers (naming scheme explained at the protected members).
00072         ValuePointer GetTotalDeltaValues();
00073 
00074         ValuePointer GetTotalDeltaBValues();
00075 
00076         ValuePointer GetDeltaValues();
00077 
00078         void SetDeltaValues(ValuePointer);
00079 
00080         void SetDWValues(ValuePointer);
00081 
00082         void SetDBValues(ValuePointer);
00083 
00084         ValuePointer GetDeltaBValues();
00085 
00086         void SetDeltaBValues(ValuePointer);
00087 
00088         ValuePointer GetDWValues();
00089         // "Prev" = previous step (_m_1); "Prev_m_2" = two steps back (_m_2).
00090         ValuePointer GetPrevDWValues();
00091 
00092         ValuePointer GetPrevDBValues();
00093 
00094         ValuePointer GetPrev_m_2DWValues();
00095 
00096         ValuePointer GetPrevDeltaValues();
00097 
00098         ValuePointer GetPrev_m_2DeltaValues();
00099 
00100         ValuePointer GetPrevDeltaBValues();
00101         // Mutable and read-only access to the weights; NOTE(review): presumably m_WeightMatrix's storage (weights plus bias column).
00102         ValuePointer GetWeightValues();
00103         ValueConstPointer GetWeightValues() const;
00104 
00105         // NOTE(review): presumably copies from weights into m_WeightMatrix; the caller must supply enough values (no length passed).
00106         void SetWeightValues(ValuePointer weights);
00107         // Applies the accumulated updates scaled by LearningRate (exact rule lives in the .txx).
00108         virtual void UpdateWeights(ValueType LearningRate);
00109         // itkSetMacro generates SetX(); itkGetConstReferenceMacro / itkGetConstMacro generate GetX().
00110         itkSetMacro( Momentum, ValueType );
00111         itkGetConstReferenceMacro( Momentum, ValueType );
00112 
00113         itkSetMacro( Bias, ValueType );
00114         itkGetConstReferenceMacro( Bias, ValueType );
00115         // NOTE(review): flags presumably tracking whether the _m_1 / _m_2 history is populated yet - confirm.
00116         itkSetMacro( FirstPass, bool );
00117         itkGetConstMacro( FirstPass, bool );
00118 
00119         itkSetMacro( SecondPass, bool );
00120         itkGetConstMacro( SecondPass, bool );
00121         // NOTE(review): presumably fills m_WeightMatrix with random values - confirm in the .txx.
00122         void InitializeWeights();
00123         // Identifiers of this weight set and of the two layers it connects.
00124         itkSetMacro(WeightSetId,unsigned int);
00125         itkGetConstMacro(WeightSetId,unsigned int);
00126 
00127         itkSetMacro(InputLayerId,unsigned int);
00128         itkGetConstMacro(InputLayerId,unsigned int);
00129 
00130         itkSetMacro(OutputLayerId,unsigned int);
00131         itkGetConstMacro(OutputLayerId,unsigned int);
00132 
00133       protected:
00134         // Protected so instances cannot be constructed or deleted directly, only through subclasses.
00135         WeightSetBase();
00136         ~WeightSetBase();
00137         // ITK-style PrintSelf: writes the object state to os.
00139         virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00140 
00141         typename RandomVariateGeneratorType::Pointer m_RandomGenerator;
00142         // Layer sizes and value buffers.
00143         unsigned int              m_NumberOfInputNodes;
00144         unsigned int              m_NumberOfOutputNodes;
00145         vnl_matrix<ValueType>     m_OutputValues;
00146         vnl_matrix<ValueType>     m_InputErrorValues;
00147 
00148         // Weight update rule: dw = lr * del * y.
00149         // m_DW     = current update
00150         // m_DW_m_1 = previous update
00151         // m_DW_m_2 = update from two steps back
00152         // The same _m_1 / _m_2 suffix scheme applies to the delta (Del, Delb) and bias (DB) members.
00153 
00154         vnl_matrix<ValueType> m_DW;            // weight deltas, current step
00155         vnl_matrix<ValueType> m_DW_new;        // weight deltas; NOTE(review): role vs m_DW not visible here
00156         vnl_matrix<ValueType> m_DW_m_1;        // weight deltas, previous step
00157         vnl_matrix<ValueType> m_DW_m_2;        // weight deltas, two steps back
00158         vnl_matrix<ValueType> m_DW_m;          // weight deltas; NOTE(review): role not visible here
00159         // NOTE(review): vnl_vector is used below but only <vnl/vnl_matrix.h> is included explicitly.
00160         vnl_vector<ValueType> m_DB;            // bias deltas, current step
00161         vnl_vector<ValueType> m_DB_new;        // bias deltas; NOTE(review): role vs m_DB not visible here
00162         vnl_vector<ValueType> m_DB_m_1;        // bias deltas, previous step
00163         vnl_vector<ValueType> m_DB_m_2;        // bias deltas, two steps back
00164 
00165         vnl_matrix<ValueType> m_Del;           // "del" term of dw = lr * del * y, current step
00166         vnl_matrix<ValueType> m_Del_new;       // "del" term; NOTE(review): role vs m_Del not visible here
00167         vnl_matrix<ValueType> m_Del_m_1;       // "del" term, previous step
00168         vnl_matrix<ValueType> m_Del_m_2;       // "del" term, two steps back
00169 
00170         vnl_vector<ValueType> m_Delb;          // "del" term for the bias update, current step
00171         vnl_vector<ValueType> m_Delb_new;      // "del" term for the bias update; NOTE(review): role vs m_Delb not visible here
00172         vnl_vector<ValueType> m_Delb_m_1;      // "del" term for the bias update, previous step
00173         vnl_vector<ValueType> m_Delb_m_2;      // "del" term for the bias update, two steps back
00174 
00175         vnl_matrix<ValueType> m_InputLayerOutput;  // NOTE(review): presumably the input-layer values given to ForwardPropagate
00176         vnl_matrix<ValueType> m_WeightMatrix;  // weights plus one column of biases
00177         // NOTE(review): which column holds the biases is not visible from this header.
00178         vnl_matrix<int>       m_ConnectivityMatrix;
00179 
00180         ValueType m_Momentum;
00181         ValueType m_Bias;
00182         bool      m_FirstPass;
00183         bool      m_SecondPass;
00184         ValueType m_Range;
00185 
00186         unsigned int m_InputLayerId;
00187         unsigned int m_OutputLayerId;
00188         unsigned int m_WeightSetId;
00189 
00190         };  // class WeightSetBase
00191 
00192     } // end namespace Statistics
00193 } // end namespace itk
00194 
00195 #ifndef ITK_MANUAL_INSTANTIATION
00196 #include "itkWeightSetBase.txx"
00197 #endif
00198 
00199 
00200 #endif
00201 

Generated at Thu Nov 6 01:03:57 2008 for ITK by doxygen 1.5.1 written by Dimitri van Heesch, © 1997-2000