Main Page   Groups   Namespace List   Class Hierarchy   Alphabetical List   Compound List   File List   Namespace Members   Compound Members   File Members   Concepts

itkWeightSetBase.h

Go to the documentation of this file.
00001 /*=========================================================================
00002 
00003 Program:   Insight Segmentation & Registration Toolkit
00004 Module:    $RCSfile: itkWeightSetBase.h,v $
00005 Language:  C++
00006 Date:      $Date: 2009-01-28 21:04:59 $
00007 Version:   $Revision: 1.10 $
00008 
00009 Copyright (c) Insight Software Consortium. All rights reserved.
00010 See ITKCopyright.txt or http://www.itk.org/HTML/Copyright.htm for details.
00011 
00012 This software is distributed WITHOUT ANY WARRANTY; without even
00013 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
00014 PURPOSE.  See the above copyright notices for more information.
00015 
00016 =========================================================================*/
00017 
00018 #ifndef __itkWeightSetBase_h
00019 #define __itkWeightSetBase_h
00020 
00021 #include "itkLightProcessObject.h"
00022 #include <vnl/vnl_matrix.h>
00023 #include <vnl/vnl_diag_matrix.h>
00024 #include "itkMacro.h"
00025 #include "itkVector.h"
00026 #include "itkMersenneTwisterRandomVariateGenerator.h"
00027 #include <math.h>
00028 #include <stdlib.h>
00029 
00030 namespace itk
00031 {
00032 namespace Statistics
00033 {
00034 
00035 template<class TMeasurementVector, class TTargetVector>
00036 class WeightSetBase : public LightProcessObject
00037 {
00038 public:
00039 
00040   typedef WeightSetBase            Self;
00041   typedef LightProcessObject       Superclass;
00042   typedef SmartPointer<Self>       Pointer;
00043   typedef SmartPointer<const Self> ConstPointer;
00044   itkTypeMacro(WeightSetBase, LightProcessObject);
00045 
00046   typedef MersenneTwisterRandomVariateGenerator  RandomVariateGeneratorType;
00047   
00048   typedef typename TMeasurementVector::ValueType ValueType;
00049   typedef ValueType*                             ValuePointer;
00050   typedef const ValueType*                       ValueConstPointer;
00051 
00052   void Initialize();
00053 
00054   ValueType RandomWeightValue(ValueType low, ValueType high);
00055 
00056   virtual void ForwardPropagate(ValuePointer inputlayeroutputvalues);
00057   virtual void BackwardPropagate(ValuePointer inputerror);
00058 
00059   void SetConnectivityMatrix(vnl_matrix < int>);
00060 
00061   void SetNumberOfInputNodes(unsigned int n);
00062   unsigned int GetNumberOfInputNodes() const;
00063 
00064   void SetNumberOfOutputNodes(unsigned int n);
00065   unsigned int GetNumberOfOutputNodes() const;
00066   
00067   void SetRange(ValueType Range);
00068 
00069   virtual ValuePointer GetOutputValues();
00070   virtual ValuePointer GetInputValues();
00071 
00072   ValuePointer GetTotalDeltaValues();
00073   ValuePointer GetTotalDeltaBValues();
00074 
00075   ValuePointer GetDeltaValues();
00076   
00077   void SetDeltaValues(ValuePointer);
00078   void SetDWValues(ValuePointer);
00079   void SetDBValues(ValuePointer);
00080   ValuePointer GetDeltaBValues();
00081   void SetDeltaBValues(ValuePointer);
00082   ValuePointer GetDWValues();
00083   ValuePointer GetPrevDWValues();
00084   ValuePointer GetPrevDBValues();
00085   ValuePointer GetPrev_m_2DWValues();
00086   ValuePointer GetPrevDeltaValues();
00087   ValuePointer GetPrev_m_2DeltaValues();
00088   ValuePointer GetPrevDeltaBValues();
00089   ValuePointer GetWeightValues();
00090   ValueConstPointer GetWeightValues() const;
00091 
00092   void SetWeightValues(ValuePointer weights);
00093   virtual void UpdateWeights(ValueType LearningRate);
00094 
00095   itkSetMacro( Momentum, ValueType );
00096   itkGetConstReferenceMacro( Momentum, ValueType );
00097 
00098   itkSetMacro( Bias, ValueType );
00099   itkGetConstReferenceMacro( Bias, ValueType );
00100 
00101   itkSetMacro( FirstPass, bool );
00102   itkGetConstMacro( FirstPass, bool );
00103 
00104   itkSetMacro( SecondPass, bool );
00105   itkGetConstMacro( SecondPass, bool );
00106 
00107   void InitializeWeights();
00108 
00109   itkSetMacro(WeightSetId,unsigned int);
00110   itkGetConstMacro(WeightSetId,unsigned int);
00111 
00112   itkSetMacro(InputLayerId,unsigned int);
00113   itkGetConstMacro(InputLayerId,unsigned int);
00114 
00115   itkSetMacro(OutputLayerId,unsigned int);
00116   itkGetConstMacro(OutputLayerId,unsigned int);
00117 
00118 protected:
00119 
00120   WeightSetBase();
00121   ~WeightSetBase();
00122 
00124   virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00125 
00126   typename RandomVariateGeneratorType::Pointer m_RandomGenerator;
00127 
00128   unsigned int              m_NumberOfInputNodes;
00129   unsigned int              m_NumberOfOutputNodes;
00130   vnl_matrix<ValueType>     m_OutputValues;
00131   vnl_matrix<ValueType>     m_InputErrorValues;
00132 
00133   // weight updates dw=lr * del *y
00134   // DW= current
00135   // DW_m_1 = previous
00136   // DW_m_2= second to last
00137   // same applies for delta and bias values
00138 
00139   vnl_matrix<ValueType> m_DW;            // delta valies for weight update
00140   vnl_matrix<ValueType> m_DW_new;        // delta valies for weight update
00141   vnl_matrix<ValueType> m_DW_m_1;        // delta valies for weight update
00142   vnl_matrix<ValueType> m_DW_m_2;        // delta valies for weight update
00143   vnl_matrix<ValueType> m_DW_m;          // delta valies for weight update
00144   
00145   vnl_vector<ValueType> m_DB;            // delta values for bias update
00146   vnl_vector<ValueType> m_DB_new;        // delta values for bias update
00147   vnl_vector<ValueType> m_DB_m_1;        // delta values for bias update
00148   vnl_vector<ValueType> m_DB_m_2;        // delta values for bias update
00149 
00150   vnl_matrix<ValueType> m_Del;           // dw=lr * del * y
00151   vnl_matrix<ValueType> m_Del_new;       // dw=lr * del * y
00152   vnl_matrix<ValueType> m_Del_m_1;       // dw=lr * del * y
00153   vnl_matrix<ValueType> m_Del_m_2;       // dw=lr * del * y
00154 
00155   vnl_vector<ValueType> m_Delb;          // delta values for bias update
00156   vnl_vector<ValueType> m_Delb_new;      // delta values for bias update
00157   vnl_vector<ValueType> m_Delb_m_1;      // delta values for bias update
00158   vnl_vector<ValueType> m_Delb_m_2;      // delta values for bias update
00159 
00160   vnl_matrix<ValueType> m_InputLayerOutput;
00161   vnl_matrix<ValueType> m_WeightMatrix;  // composed of weights and a column
00162   // of biases
00163   vnl_matrix<int>       m_ConnectivityMatrix;
00164 
00165   ValueType m_Momentum;
00166   ValueType m_Bias;
00167   bool      m_FirstPass;
00168   bool      m_SecondPass;
00169   ValueType m_Range;
00170 
00171   unsigned int m_InputLayerId;
00172   unsigned int m_OutputLayerId;
00173   unsigned int m_WeightSetId;
00174 
00175 };  //class
00176 
00177 } // end namespace Statistics
00178 } // end namespace itk
00179 
00180 #ifndef ITK_MANUAL_INSTANTIATION
00181 #include "itkWeightSetBase.txx"
00182 #endif
00183 
00184 
00185 #endif
00186 

Generated at Mon Jul 12 2010 20:19:10 for ITK by doxygen 1.7.1 written by Dimitri van Heesch, © 1997-2000