ITK  4.0.0
Insight Segmentation and Registration Toolkit
itkWeightSetBase.h
Go to the documentation of this file.
00001 /*=========================================================================
00002  *
00003  *  Copyright Insight Software Consortium
00004  *
00005  *  Licensed under the Apache License, Version 2.0 (the "License");
00006  *  you may not use this file except in compliance with the License.
00007  *  You may obtain a copy of the License at
00008  *
00009  *         http://www.apache.org/licenses/LICENSE-2.0.txt
00010  *
00011  *  Unless required by applicable law or agreed to in writing, software
00012  *  distributed under the License is distributed on an "AS IS" BASIS,
00013  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00014  *  See the License for the specific language governing permissions and
00015  *  limitations under the License.
00016  *
00017  *=========================================================================*/
00018 #ifndef __itkWeightSetBase_h
00019 #define __itkWeightSetBase_h
00020 
00021 #include "itkLightProcessObject.h"
00022 #include "vnl/vnl_matrix.h"
00023 #include "vnl/vnl_diag_matrix.h"
00024 #include "itkVector.h"
00025 #include "itkMersenneTwisterRandomVariateGenerator.h"
00026 #include <math.h>
00027 #include <stdlib.h>
00028 
00029 namespace itk
00030 {
00031 namespace Statistics
00032 {
00039 template<class TMeasurementVector, class TTargetVector>
00040 class WeightSetBase : public LightProcessObject
00041 {
        // WeightSetBase: declaration of a weight-set class template derived from
        // LightProcessObject. Per the members declared below it stores a weight
        // matrix (weights plus a column of biases), an integer connectivity
        // matrix, input/output node counts, and current/previous copies of the
        // weight, bias and delta update terms.
        //
        // Template parameters:
        //   TMeasurementVector - supplies ValueType (TMeasurementVector::ValueType),
        //                        the scalar type used by every matrix/vector member.
        //   TTargetVector      - not referenced anywhere in this declaration.
        //
        // Only declarations appear here; the member definitions are expected in
        // itkWeightSetBase.hxx, which this header includes unless
        // ITK_MANUAL_INSTANTIATION is defined.
        // NOTE(review): the constructor is protected and no itkNewMacro is
        // declared here, so this looks like a base class meant to be subclassed -
        // confirm against the derived weight-set classes.
00042 public:
00043 
        // Standard ITK class typedefs and run-time type macro.
00044   typedef WeightSetBase            Self;
00045   typedef LightProcessObject       Superclass;
00046   typedef SmartPointer<Self>       Pointer;
00047   typedef SmartPointer<const Self> ConstPointer;
00048   itkTypeMacro(WeightSetBase, LightProcessObject);
00049 
        // Type of the random generator held in m_RandomGenerator.
00050   typedef MersenneTwisterRandomVariateGenerator  RandomVariateGeneratorType;
00051 
        // Scalar type and raw (mutable / const) pointers to it. Most accessors
        // below exchange data through these raw pointers.
        // NOTE(review): the pointers carry no length or ownership information;
        // presumably they alias the storage of the vnl members - confirm in the
        // .hxx before holding on to a returned pointer.
00052   typedef typename TMeasurementVector::ValueType ValueType;
00053   typedef ValueType*                             ValuePointer;
00054   typedef const ValueType*                       ValueConstPointer;
00055 
        // NOTE(review): presumably sizes/initializes the internal matrices from
        // the configured node counts - confirm in the .hxx.
00056   void Initialize();
00057 
        // Returns a ValueType given a (low, high) pair.
        // NOTE(review): presumably a random value drawn between low and high
        // using m_RandomGenerator; bound inclusivity is not visible here.
00058   ValueType RandomWeightValue(ValueType low, ValueType high);
00059 
        // Virtual propagation hooks; subclasses may override them. Each takes a
        // raw pointer to ValueType data (layer outputs / error values per the
        // parameter names).
00060   virtual void ForwardPropagate(ValuePointer inputlayeroutputvalues);
00061   virtual void BackwardPropagate(ValuePointer inputerror);
00062 
        // Takes the connectivity matrix by value (the argument is copied at the
        // call site).
00063   void SetConnectivityMatrix(vnl_matrix < int>);
00064 
        // Node-count accessors (see m_NumberOfInputNodes / m_NumberOfOutputNodes).
00065   void SetNumberOfInputNodes(unsigned int n);
00066   unsigned int GetNumberOfInputNodes() const;
00067 
00068   void SetNumberOfOutputNodes(unsigned int n);
00069   unsigned int GetNumberOfOutputNodes() const;
00070 
        // NOTE(review): presumably stores m_Range, the range used when weights
        // are randomly initialized - confirm in the .hxx.
00071   void SetRange(ValueType Range);
00072 
        // Virtual raw-pointer accessors to the output / input value buffers.
00073   virtual ValuePointer GetOutputValues();
00074   virtual ValuePointer GetInputValues();
00075 
00076   ValuePointer GetTotalDeltaValues();
00077   ValuePointer GetTotalDeltaBValues();
00078 
00079   ValuePointer GetDeltaValues();
00080 
        // Raw-pointer getters/setters for the update terms. By name these map
        // onto the m_DW*, m_DB*, m_Del* and m_Delb* members declared below
        // ("Prev" ~ *_m_1, "Prev_m_2" ~ *_m_2) - TODO confirm the exact mapping
        // in the .hxx.
00081   void SetDeltaValues(ValuePointer);
00082   void SetDWValues(ValuePointer);
00083   void SetDBValues(ValuePointer);
00084   ValuePointer GetDeltaBValues();
00085   void SetDeltaBValues(ValuePointer);
00086   ValuePointer GetDWValues();
00087   ValuePointer GetPrevDWValues();
00088   ValuePointer GetPrevDBValues();
00089   ValuePointer GetPrev_m_2DWValues();
00090   ValuePointer GetPrevDeltaValues();
00091   ValuePointer GetPrev_m_2DeltaValues();
00092   ValuePointer GetPrevDeltaBValues();
        // Mutable and const overloads for reading the weight values.
00093   ValuePointer GetWeightValues();
00094   ValueConstPointer GetWeightValues() const;
00095 
00096   void SetWeightValues(ValuePointer weights);
        // Virtual so subclasses can supply their own update rule.
00097   virtual void UpdateWeights(ValueType LearningRate);
00098 
        // ITK accessor macros: declare Set<Name>() / Get<Name>() for the
        // corresponding m_<Name> member declared in the protected section.
00099   itkSetMacro( Momentum, ValueType );
00100   itkGetConstReferenceMacro( Momentum, ValueType );
00101 
00102   itkSetMacro( Bias, ValueType );
00103   itkGetConstReferenceMacro( Bias, ValueType );
00104 
00105   itkSetMacro( FirstPass, bool );
00106   itkGetConstMacro( FirstPass, bool );
00107 
00108   itkSetMacro( SecondPass, bool );
00109   itkGetConstMacro( SecondPass, bool );
00110 
00111   void InitializeWeights();
00112 
        // Identifiers for this weight set and for the layers it refers to.
00113   itkSetMacro(WeightSetId,unsigned int);
00114   itkGetConstMacro(WeightSetId,unsigned int);
00115 
00116   itkSetMacro(InputLayerId,unsigned int);
00117   itkGetConstMacro(InputLayerId,unsigned int);
00118 
00119   itkSetMacro(OutputLayerId,unsigned int);
00120   itkGetConstMacro(OutputLayerId,unsigned int);
00121 
00122 protected:
00123 
        // Protected construction/destruction: objects cannot be created or
        // destroyed directly by outside code.
00124   WeightSetBase();
00125   ~WeightSetBase();
00126 
        // Prints the object's state to os using the given indentation.
00128   virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00129 
        // Smart pointer to the Mersenne Twister generator.
00130   typename RandomVariateGeneratorType::Pointer m_RandomGenerator;
00131 
00132   unsigned int              m_NumberOfInputNodes;
00133   unsigned int              m_NumberOfOutputNodes;
00134   vnl_matrix<ValueType>     m_OutputValues;
00135   vnl_matrix<ValueType>     m_InputErrorValues;
00136 
00137   // weight updates: dw = lr * del * y
00138   // DW      = current
00139   // DW_m_1  = previous
00140   // DW_m_2  = second to last
00141   // the same suffix convention applies to the delta and bias values
00142 
00143   vnl_matrix<ValueType> m_DW;            // delta values for weight update
00144   vnl_matrix<ValueType> m_DW_new;        // delta values for weight update
00145   vnl_matrix<ValueType> m_DW_m_1;        // delta values for weight update
00146   vnl_matrix<ValueType> m_DW_m_2;        // delta values for weight update
00147   vnl_matrix<ValueType> m_DW_m;          // delta values for weight update
00148 
00149   vnl_vector<ValueType> m_DB;            // delta values for bias update
00150   vnl_vector<ValueType> m_DB_new;        // delta values for bias update
00151   vnl_vector<ValueType> m_DB_m_1;        // delta values for bias update
00152   vnl_vector<ValueType> m_DB_m_2;        // delta values for bias update
00153 
00154   vnl_matrix<ValueType> m_Del;           // dw=lr * del * y
00155   vnl_matrix<ValueType> m_Del_new;       // dw=lr * del * y
00156   vnl_matrix<ValueType> m_Del_m_1;       // dw=lr * del * y
00157   vnl_matrix<ValueType> m_Del_m_2;       // dw=lr * del * y
00158 
00159   vnl_vector<ValueType> m_Delb;          // delta values for bias update
00160   vnl_vector<ValueType> m_Delb_new;      // delta values for bias update
00161   vnl_vector<ValueType> m_Delb_m_1;      // delta values for bias update
00162   vnl_vector<ValueType> m_Delb_m_2;      // delta values for bias update
00163 
00164   vnl_matrix<ValueType> m_InputLayerOutput;
00165   vnl_matrix<ValueType> m_WeightMatrix;  // composed of weights and a column
00166   // of biases
00167   vnl_matrix<int>       m_ConnectivityMatrix;
00168 
        // Scalars exposed through the itkSetMacro/itkGet*Macro accessors above
        // (m_Range is only settable, via SetRange()).
00169   ValueType m_Momentum;
00170   ValueType m_Bias;
00171   bool      m_FirstPass;
00172   bool      m_SecondPass;
00173   ValueType m_Range;
00174 
00175   unsigned int m_InputLayerId;
00176   unsigned int m_OutputLayerId;
00177   unsigned int m_WeightSetId;
00178 
00179 };  //class
00180 
00181 } // end namespace Statistics
00182 } // end namespace itk
00183 
00184 #ifndef ITK_MANUAL_INSTANTIATION
00185 #include "itkWeightSetBase.hxx"
00186 #endif
00187 
00188 
00189 #endif
00190