Main Page   Groups   Namespace List   Class Hierarchy   Alphabetical List   Compound List   File List   Namespace Members   Compound Members   File Members   Concepts

itkWeightSetBase.h

Go to the documentation of this file.
00001 /*=========================================================================
00002 
00003   Program:   Insight Segmentation & Registration Toolkit
00004   Module:    $RCSfile: itkWeightSetBase.h,v $
00005   Language:  C++
00006   Date:      $Date: 2007/01/30 12:47:27 $
00007   Version:   $Revision: 1.8 $
00008 
00009   Copyright (c) Insight Software Consortium. All rights reserved.
00010   See ITKCopyright.txt or http://www.itk.org/HTML/Copyright.htm for details.
00011 
00012      This software is distributed WITHOUT ANY WARRANTY; without even 
00013      the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
00014      PURPOSE.  See the above copyright notices for more information.
00015 
00016 =========================================================================*/
00017 
00018 #ifndef __itkWeightSetBase_h
00019 #define __itkWeightSetBase_h
00020 
00021 #include "itkLayerBase.h"
00022 #include "itkLightProcessObject.h"
00023 #include <vnl/vnl_matrix.h>
00024 #include <vnl/vnl_diag_matrix.h>
00025 #include "itkMacro.h"
00026 #include "itkVector.h"
00027 #include "itkMersenneTwisterRandomVariateGenerator.h"
00028 #include <math.h>
00029 #include <stdlib.h>
00030 
00031 namespace itk
00032 {
00033 namespace Statistics
00034 {
00035 
00036 template<class TVector, class TOutput>
00037 class WeightSetBase : public LightProcessObject
00038 {
00039 public:
00040   
00041   typedef WeightSetBase            Self;
00042   typedef LightProcessObject       Superclass;
00043   typedef SmartPointer<Self>       Pointer;
00044   typedef SmartPointer<const Self> ConstPointer;
00045   itkTypeMacro(WeightSetBase, LightProcessObject);
00046 
00047   typedef MersenneTwisterRandomVariateGenerator  RandomVariateGeneratorType;
00048 
00049   typedef typename TVector::ValueType ValueType;
00050   typedef ValueType*                  ValuePointer;
00051   typedef const ValueType*            ValueConstPointer;
00052 
00053   void Initialize();
00054 
00055   ValueType RandomWeightValue(ValueType low, ValueType high);
00056 
00057   void ForwardPropagate(ValuePointer inputlayeroutputvalues);
00058 
00059   void BackwardPropagate(ValuePointer inputerror);
00060 
00061   void SetConnectivityMatrix(vnl_matrix < int>);
00062 
00063   void SetNumberOfInputNodes(unsigned int n);
00064   unsigned int GetNumberOfInputNodes() const;
00065 
00066   void SetNumberOfOutputNodes(unsigned int n);
00067   unsigned int GetNumberOfOutputNodes() const;
00068 
00069   void SetRange(ValueType Range);
00070  
00071   ValuePointer GetOutputValues();
00072 
00073   ValuePointer GetInputValues();
00074   
00075   ValuePointer GetTotalDeltaValues();
00076   
00077   ValuePointer GetTotalDeltaBValues();
00078   
00079   ValuePointer GetDeltaValues();
00080 
00081   void SetDeltaValues(ValuePointer);
00082   
00083   void SetDWValues(ValuePointer);
00084 
00085   void SetDBValues(ValuePointer);
00086 
00087   ValuePointer GetDeltaBValues();
00088   
00089   void SetDeltaBValues(ValuePointer);
00090 
00091   ValuePointer GetDWValues();
00092 
00093   ValuePointer GetPrevDWValues();
00094 
00095   ValuePointer GetPrevDBValues();
00096 
00097   ValuePointer GetPrev_m_2DWValues();
00098 
00099   ValuePointer GetPrevDeltaValues();
00100 
00101   ValuePointer GetPrev_m_2DeltaValues();
00102 
00103   ValuePointer GetPrevDeltaBValues();
00104 
00105   ValuePointer GetWeightValues();
00106   ValueConstPointer GetWeightValues() const;
00107 
00108   
00109   void SetWeightValues(ValuePointer weights);
00110 
00111   void UpdateWeights(ValueType LearningRate);
00112 
00113   itkSetMacro( Momentum, ValueType );
00114   itkGetConstReferenceMacro( Momentum, ValueType );
00115 
00116   itkSetMacro( Bias, ValueType );
00117   itkGetConstReferenceMacro( Bias, ValueType );
00118 
00119   itkSetMacro( FirstPass, bool );
00120   itkGetConstMacro( FirstPass, bool );
00121 
00122   itkSetMacro( SecondPass, bool );
00123   itkGetConstMacro( SecondPass, bool );
00124 
00125   void InitializeWeights();
00126 
00127   itkSetMacro(WeightSetId,unsigned int);
00128   itkGetConstMacro(WeightSetId,unsigned int);
00129   
00130   itkSetMacro(InputLayerId,unsigned int);
00131   itkGetConstMacro(InputLayerId,unsigned int);
00132   
00133   itkSetMacro(OutputLayerId,unsigned int);
00134   itkGetConstMacro(OutputLayerId,unsigned int);
00135   
00136 protected: 
00137   
00138   WeightSetBase();
00139   ~WeightSetBase();
00140 
00142   virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00143 
00144   typename RandomVariateGeneratorType::Pointer m_RandomGenerator;
00145 
00146   unsigned int              m_NumberOfInputNodes;
00147   unsigned int              m_NumberOfOutputNodes;
00148   vnl_matrix<ValueType>     m_OutputValues;
00149   vnl_matrix<ValueType>     m_InputErrorValues;
00150   
00151   // weight updates dw=lr * del *y
00152   // DW= current
00153   // DW_m_1 = previous
00154   // DW_m_2= second to last
00155   // same applies for delta and bias values
00156 
00157   vnl_matrix<ValueType> m_DW;            // delta valies for weight update
00158   vnl_matrix<ValueType> m_DW_new;        // delta valies for weight update
00159   vnl_matrix<ValueType> m_DW_m_1;        // delta valies for weight update
00160   vnl_matrix<ValueType> m_DW_m_2;        // delta valies for weight update
00161   vnl_matrix<ValueType> m_DW_m;          // delta valies for weight update
00162   
00163   vnl_vector<ValueType> m_DB;            // delta values for bias update
00164   vnl_vector<ValueType> m_DB_new;        // delta values for bias update
00165   vnl_vector<ValueType> m_DB_m_1;        // delta values for bias update
00166   vnl_vector<ValueType> m_DB_m_2;        // delta values for bias update
00167   
00168   vnl_matrix<ValueType> m_Del;           // dw=lr * del * y
00169   vnl_matrix<ValueType> m_Del_new;       // dw=lr * del * y
00170   vnl_matrix<ValueType> m_Del_m_1;       // dw=lr * del * y
00171   vnl_matrix<ValueType> m_Del_m_2;       // dw=lr * del * y
00172   
00173   vnl_vector<ValueType> m_Delb;          // delta values for bias update
00174   vnl_vector<ValueType> m_Delb_new;      // delta values for bias update
00175   vnl_vector<ValueType> m_Delb_m_1;      // delta values for bias update
00176   vnl_vector<ValueType> m_Delb_m_2;      // delta values for bias update
00177 
00178   vnl_matrix<ValueType> m_InputLayerOutput;
00179   vnl_matrix<ValueType> m_WeightMatrix;  // composed of weights and a column
00180                                          // of biases
00181   vnl_matrix<int>       m_ConnectivityMatrix;
00182   
00183   ValueType m_Momentum;
00184   ValueType m_Bias;
00185   bool      m_FirstPass;
00186   bool      m_SecondPass;
00187   ValueType m_Range;
00188 
00189   unsigned int m_InputLayerId;
00190   unsigned int m_OutputLayerId;
00191   unsigned int m_WeightSetId;
00192 
00193 };  //class
00194 
00195 } // end namespace Statistics
00196 } // end namespace itk
00197 
00198 #ifndef ITK_MANUAL_INSTANTIATION
00199   #include "itkWeightSetBase.txx"
00200 #endif
00201 
00202 
00203 #endif
00204 

Generated at Mon Mar 12 03:31:41 2007 for ITK by doxygen 1.5.1 written by Dimitri van Heesch, © 1997-2000