00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018 #ifndef __itkWeightSetBase_h
00019 #define __itkWeightSetBase_h
00020
00021 #include "itkLightProcessObject.h"
00022 #include <vnl/vnl_matrix.h>
00023 #include <vnl/vnl_diag_matrix.h>
00024 #include "itkMacro.h"
00025 #include "itkVector.h"
00026 #include "itkMersenneTwisterRandomVariateGenerator.h"
00027 #include <math.h>
00028 #include <stdlib.h>
00029
00030 namespace itk
00031 {
00032 namespace Statistics
00033 {
00034
00035 template<class TMeasurementVector, class TTargetVector>
00036 class WeightSetBase : public LightProcessObject
00037 {
00038 public:
00039
00040 typedef WeightSetBase Self;
00041 typedef LightProcessObject Superclass;
00042 typedef SmartPointer<Self> Pointer;
00043 typedef SmartPointer<const Self> ConstPointer;
00044 itkTypeMacro(WeightSetBase, LightProcessObject);
00045
00046 typedef MersenneTwisterRandomVariateGenerator RandomVariateGeneratorType;
00047
00048 typedef typename TMeasurementVector::ValueType ValueType;
00049 typedef ValueType* ValuePointer;
00050 typedef const ValueType* ValueConstPointer;
00051
00052 void Initialize();
00053
00054 ValueType RandomWeightValue(ValueType low, ValueType high);
00055
00056 virtual void ForwardPropagate(ValuePointer inputlayeroutputvalues);
00057 virtual void BackwardPropagate(ValuePointer inputerror);
00058
00059 void SetConnectivityMatrix(vnl_matrix < int>);
00060
00061 void SetNumberOfInputNodes(unsigned int n);
00062 unsigned int GetNumberOfInputNodes() const;
00063
00064 void SetNumberOfOutputNodes(unsigned int n);
00065 unsigned int GetNumberOfOutputNodes() const;
00066
00067 void SetRange(ValueType Range);
00068
00069 virtual ValuePointer GetOutputValues();
00070 virtual ValuePointer GetInputValues();
00071
00072 ValuePointer GetTotalDeltaValues();
00073
00074 ValuePointer GetTotalDeltaBValues();
00075
00076 ValuePointer GetDeltaValues();
00077
00078 void SetDeltaValues(ValuePointer);
00079
00080 void SetDWValues(ValuePointer);
00081
00082 void SetDBValues(ValuePointer);
00083
00084 ValuePointer GetDeltaBValues();
00085
00086 void SetDeltaBValues(ValuePointer);
00087
00088 ValuePointer GetDWValues();
00089
00090 ValuePointer GetPrevDWValues();
00091
00092 ValuePointer GetPrevDBValues();
00093
00094 ValuePointer GetPrev_m_2DWValues();
00095
00096 ValuePointer GetPrevDeltaValues();
00097
00098 ValuePointer GetPrev_m_2DeltaValues();
00099
00100 ValuePointer GetPrevDeltaBValues();
00101
00102 ValuePointer GetWeightValues();
00103 ValueConstPointer GetWeightValues() const;
00104
00105
00106 void SetWeightValues(ValuePointer weights);
00107
00108 virtual void UpdateWeights(ValueType LearningRate);
00109
00110 itkSetMacro( Momentum, ValueType );
00111 itkGetConstReferenceMacro( Momentum, ValueType );
00112
00113 itkSetMacro( Bias, ValueType );
00114 itkGetConstReferenceMacro( Bias, ValueType );
00115
00116 itkSetMacro( FirstPass, bool );
00117 itkGetConstMacro( FirstPass, bool );
00118
00119 itkSetMacro( SecondPass, bool );
00120 itkGetConstMacro( SecondPass, bool );
00121
00122 void InitializeWeights();
00123
00124 itkSetMacro(WeightSetId,unsigned int);
00125 itkGetConstMacro(WeightSetId,unsigned int);
00126
00127 itkSetMacro(InputLayerId,unsigned int);
00128 itkGetConstMacro(InputLayerId,unsigned int);
00129
00130 itkSetMacro(OutputLayerId,unsigned int);
00131 itkGetConstMacro(OutputLayerId,unsigned int);
00132
00133 protected:
00134
00135 WeightSetBase();
00136 ~WeightSetBase();
00137
00139 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00140
00141 typename RandomVariateGeneratorType::Pointer m_RandomGenerator;
00142
00143 unsigned int m_NumberOfInputNodes;
00144 unsigned int m_NumberOfOutputNodes;
00145 vnl_matrix<ValueType> m_OutputValues;
00146 vnl_matrix<ValueType> m_InputErrorValues;
00147
00148
00149
00150
00151
00152
00153
00154 vnl_matrix<ValueType> m_DW;
00155 vnl_matrix<ValueType> m_DW_new;
00156 vnl_matrix<ValueType> m_DW_m_1;
00157 vnl_matrix<ValueType> m_DW_m_2;
00158 vnl_matrix<ValueType> m_DW_m;
00159
00160 vnl_vector<ValueType> m_DB;
00161 vnl_vector<ValueType> m_DB_new;
00162 vnl_vector<ValueType> m_DB_m_1;
00163 vnl_vector<ValueType> m_DB_m_2;
00164
00165 vnl_matrix<ValueType> m_Del;
00166 vnl_matrix<ValueType> m_Del_new;
00167 vnl_matrix<ValueType> m_Del_m_1;
00168 vnl_matrix<ValueType> m_Del_m_2;
00169
00170 vnl_vector<ValueType> m_Delb;
00171 vnl_vector<ValueType> m_Delb_new;
00172 vnl_vector<ValueType> m_Delb_m_1;
00173 vnl_vector<ValueType> m_Delb_m_2;
00174
00175 vnl_matrix<ValueType> m_InputLayerOutput;
00176 vnl_matrix<ValueType> m_WeightMatrix;
00177
00178 vnl_matrix<int> m_ConnectivityMatrix;
00179
00180 ValueType m_Momentum;
00181 ValueType m_Bias;
00182 bool m_FirstPass;
00183 bool m_SecondPass;
00184 ValueType m_Range;
00185
00186 unsigned int m_InputLayerId;
00187 unsigned int m_OutputLayerId;
00188 unsigned int m_WeightSetId;
00189
00190 };
00191
00192 }
00193 }
00194
00195 #ifndef ITK_MANUAL_INSTANTIATION
00196 #include "itkWeightSetBase.txx"
00197 #endif
00198
00199
00200 #endif
00201