00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018 #ifndef __itkWeightSetBase_h
00019 #define __itkWeightSetBase_h
00020
00021 #include "itkLayerBase.h"
00022 #include "itkLightProcessObject.h"
00023 #include <vnl/vnl_matrix.h>
00024 #include <vnl/vnl_diag_matrix.h>
00025 #include "itkMacro.h"
00026 #include "itkVector.h"
00027 #include "itkMersenneTwisterRandomVariateGenerator.h"
00028 #include <math.h>
00029 #include <stdlib.h>
00030
00031 namespace itk
00032 {
00033 namespace Statistics
00034 {
00035
00036 template<class TVector, class TOutput>
00037 class WeightSetBase : public LightProcessObject
00038 {
00039 public:
00040
00041 typedef WeightSetBase Self;
00042 typedef LightProcessObject Superclass;
00043 typedef SmartPointer<Self> Pointer;
00044 typedef SmartPointer<const Self> ConstPointer;
00045 itkTypeMacro(WeightSetBase, LightProcessObject);
00046
00047 typedef MersenneTwisterRandomVariateGenerator RandomVariateGeneratorType;
00048
00049 typedef typename TVector::ValueType ValueType;
00050 typedef ValueType* ValuePointer;
00051 typedef const ValueType* ValueConstPointer;
00052
00053 void Initialize();
00054
00055 ValueType RandomWeightValue(ValueType low, ValueType high);
00056
00057 void ForwardPropagate(ValuePointer inputlayeroutputvalues);
00058
00059 void BackwardPropagate(ValuePointer inputerror);
00060
00061 void SetConnectivityMatrix(vnl_matrix < int>);
00062
00063 void SetNumberOfInputNodes(unsigned int n);
00064 unsigned int GetNumberOfInputNodes() const;
00065
00066 void SetNumberOfOutputNodes(unsigned int n);
00067 unsigned int GetNumberOfOutputNodes() const;
00068
00069 void SetRange(ValueType Range);
00070
00071 ValuePointer GetOutputValues();
00072
00073 ValuePointer GetInputValues();
00074
00075 ValuePointer GetTotalDeltaValues();
00076
00077 ValuePointer GetTotalDeltaBValues();
00078
00079 ValuePointer GetDeltaValues();
00080
00081 void SetDeltaValues(ValuePointer);
00082
00083 void SetDWValues(ValuePointer);
00084
00085 void SetDBValues(ValuePointer);
00086
00087 ValuePointer GetDeltaBValues();
00088
00089 void SetDeltaBValues(ValuePointer);
00090
00091 ValuePointer GetDWValues();
00092
00093 ValuePointer GetPrevDWValues();
00094
00095 ValuePointer GetPrevDBValues();
00096
00097 ValuePointer GetPrev_m_2DWValues();
00098
00099 ValuePointer GetPrevDeltaValues();
00100
00101 ValuePointer GetPrev_m_2DeltaValues();
00102
00103 ValuePointer GetPrevDeltaBValues();
00104
00105 ValuePointer GetWeightValues();
00106 ValueConstPointer GetWeightValues() const;
00107
00108
00109 void SetWeightValues(ValuePointer weights);
00110
00111 void UpdateWeights(ValueType LearningRate);
00112
00113 itkSetMacro( Momentum, ValueType );
00114 itkGetConstReferenceMacro( Momentum, ValueType );
00115
00116 itkSetMacro( Bias, ValueType );
00117 itkGetConstReferenceMacro( Bias, ValueType );
00118
00119 itkSetMacro( FirstPass, bool );
00120 itkGetConstMacro( FirstPass, bool );
00121
00122 itkSetMacro( SecondPass, bool );
00123 itkGetConstMacro( SecondPass, bool );
00124
00125 void InitializeWeights();
00126
00127 itkSetMacro(WeightSetId,unsigned int);
00128 itkGetConstMacro(WeightSetId,unsigned int);
00129
00130 itkSetMacro(InputLayerId,unsigned int);
00131 itkGetConstMacro(InputLayerId,unsigned int);
00132
00133 itkSetMacro(OutputLayerId,unsigned int);
00134 itkGetConstMacro(OutputLayerId,unsigned int);
00135
00136 protected:
00137
00138 WeightSetBase();
00139 ~WeightSetBase();
00140
00142 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00143
00144 typename RandomVariateGeneratorType::Pointer m_RandomGenerator;
00145
00146 unsigned int m_NumberOfInputNodes;
00147 unsigned int m_NumberOfOutputNodes;
00148 vnl_matrix<ValueType> m_OutputValues;
00149 vnl_matrix<ValueType> m_InputErrorValues;
00150
00151
00152
00153
00154
00155
00156
00157 vnl_matrix<ValueType> m_DW;
00158 vnl_matrix<ValueType> m_DW_new;
00159 vnl_matrix<ValueType> m_DW_m_1;
00160 vnl_matrix<ValueType> m_DW_m_2;
00161 vnl_matrix<ValueType> m_DW_m;
00162
00163 vnl_vector<ValueType> m_DB;
00164 vnl_vector<ValueType> m_DB_new;
00165 vnl_vector<ValueType> m_DB_m_1;
00166 vnl_vector<ValueType> m_DB_m_2;
00167
00168 vnl_matrix<ValueType> m_Del;
00169 vnl_matrix<ValueType> m_Del_new;
00170 vnl_matrix<ValueType> m_Del_m_1;
00171 vnl_matrix<ValueType> m_Del_m_2;
00172
00173 vnl_vector<ValueType> m_Delb;
00174 vnl_vector<ValueType> m_Delb_new;
00175 vnl_vector<ValueType> m_Delb_m_1;
00176 vnl_vector<ValueType> m_Delb_m_2;
00177
00178 vnl_matrix<ValueType> m_InputLayerOutput;
00179 vnl_matrix<ValueType> m_WeightMatrix;
00180
00181 vnl_matrix<int> m_ConnectivityMatrix;
00182
00183 ValueType m_Momentum;
00184 ValueType m_Bias;
00185 bool m_FirstPass;
00186 bool m_SecondPass;
00187 ValueType m_Range;
00188
00189 unsigned int m_InputLayerId;
00190 unsigned int m_OutputLayerId;
00191 unsigned int m_WeightSetId;
00192
00193 };
00194
00195 }
00196 }
00197
00198 #ifndef ITK_MANUAL_INSTANTIATION
00199 #include "itkWeightSetBase.txx"
00200 #endif
00201
00202
00203 #endif
00204