00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018 #ifndef __itkWeightSetBase_h
00019 #define __itkWeightSetBase_h
00020
00021 #include "itkLightProcessObject.h"
00022 #include <vnl/vnl_matrix.h>
00023 #include <vnl/vnl_diag_matrix.h>
00024 #include "itkMacro.h"
00025 #include "itkVector.h"
00026 #include "itkMersenneTwisterRandomVariateGenerator.h"
00027 #include <math.h>
00028 #include <stdlib.h>
00029
00030 namespace itk
00031 {
00032 namespace Statistics
00033 {
00034 /** Declares WeightSetBase, a LightProcessObject subclass; member definitions come from itkWeightSetBase.txx (not shown here). */
00035 template<class TMeasurementVector, class TTargetVector> // TTargetVector is not referenced anywhere in this declaration
00036 class WeightSetBase : public LightProcessObject
00037 {
00038 public:
00039 /** Self/Superclass/smart-pointer typedefs; itkTypeMacro comes from itkMacro.h. No itkNewMacro is declared in this class. */
00040 typedef WeightSetBase Self;
00041 typedef LightProcessObject Superclass;
00042 typedef SmartPointer<Self> Pointer;
00043 typedef SmartPointer<const Self> ConstPointer;
00044 itkTypeMacro(WeightSetBase, LightProcessObject);
00045 /** Generator type used for the m_RandomGenerator member below. */
00046 typedef MersenneTwisterRandomVariateGenerator RandomVariateGeneratorType;
00047 /** Scalar element type, taken from the measurement vector, and raw-pointer aliases to it. */
00048 typedef typename TMeasurementVector::ValueType ValueType;
00049 typedef ValueType* ValuePointer;
00050 typedef const ValueType* ValueConstPointer;
00051 /** NOTE(review): presumably sizes the storage members from the node counts -- confirm against the .txx. */
00052 void Initialize();
00053 /** NOTE(review): presumably draws a value between low and high via m_RandomGenerator -- confirm. */
00054 ValueType RandomWeightValue(ValueType low, ValueType high);
00055 /** Virtual propagation hooks; each takes a raw pointer whose element count is not expressed in the signature. */
00056 virtual void ForwardPropagate(ValuePointer inputlayeroutputvalues);
00057 virtual void BackwardPropagate(ValuePointer inputerror);
00058 /** The matrix argument is passed by value, so it is copied at the call site. */
00059 void SetConnectivityMatrix(vnl_matrix < int>);
00060 /** Input node count; presumably backed by m_NumberOfInputNodes. */
00061 void SetNumberOfInputNodes(unsigned int n);
00062 unsigned int GetNumberOfInputNodes() const;
00063 /** Output node count; presumably backed by m_NumberOfOutputNodes. */
00064 void SetNumberOfOutputNodes(unsigned int n);
00065 unsigned int GetNumberOfOutputNodes() const;
00066 /** Presumably stores into m_Range; no matching getter is declared in this class. */
00067 void SetRange(ValueType Range);
00068 /** Virtual accessors returning mutable raw pointers. NOTE(review): pointee lifetime is not visible here -- confirm. */
00069 virtual ValuePointer GetOutputValues();
00070 virtual ValuePointer GetInputValues();
00071 /** Non-virtual raw-pointer accessors; which members back them is defined in the .txx (not shown). */
00072 ValuePointer GetTotalDeltaValues();
00073 ValuePointer GetTotalDeltaBValues();
00074 /** Pointer-based getters and setters; buffer lengths are not part of any signature below. */
00075 ValuePointer GetDeltaValues();
00076 /** NOTE(review): names suggest these map onto the m_Del, m_Delb, m_DW and m_DB member families -- confirm. */
00077 void SetDeltaValues(ValuePointer);
00078 void SetDWValues(ValuePointer);
00079 void SetDBValues(ValuePointer);
00080 ValuePointer GetDeltaBValues();
00081 void SetDeltaBValues(ValuePointer);
00082 ValuePointer GetDWValues();
00083 ValuePointer GetPrevDWValues();
00084 ValuePointer GetPrevDBValues();
00085 ValuePointer GetPrev_m_2DWValues();
00086 ValuePointer GetPrevDeltaValues();
00087 ValuePointer GetPrev_m_2DeltaValues();
00088 ValuePointer GetPrevDeltaBValues();
00089 ValuePointer GetWeightValues(); // mutable overload
00090 ValueConstPointer GetWeightValues() const; // read-only overload for const objects
00091 /** Weight assignment from a raw buffer, and a virtual update step taking the learning rate. */
00092 void SetWeightValues(ValuePointer weights);
00093 virtual void UpdateWeights(ValueType LearningRate);
00094 /** Macro-generated accessors (itkMacro.h) for the m_Momentum member. */
00095 itkSetMacro( Momentum, ValueType );
00096 itkGetConstReferenceMacro( Momentum, ValueType );
00097 /** Macro-generated accessors for the m_Bias member. */
00098 itkSetMacro( Bias, ValueType );
00099 itkGetConstReferenceMacro( Bias, ValueType );
00100 /** Macro-generated accessors for the m_FirstPass flag. */
00101 itkSetMacro( FirstPass, bool );
00102 itkGetConstMacro( FirstPass, bool );
00103 /** Macro-generated accessors for the m_SecondPass flag. */
00104 itkSetMacro( SecondPass, bool );
00105 itkGetConstMacro( SecondPass, bool );
00106 /** NOTE(review): presumably fills m_WeightMatrix with initial values -- confirm against the .txx. */
00107 void InitializeWeights();
00108 /** Macro-generated accessors for the m_WeightSetId member. */
00109 itkSetMacro(WeightSetId,unsigned int);
00110 itkGetConstMacro(WeightSetId,unsigned int);
00111 /** Macro-generated accessors for the m_InputLayerId member. */
00112 itkSetMacro(InputLayerId,unsigned int);
00113 itkGetConstMacro(InputLayerId,unsigned int);
00114 /** Macro-generated accessors for the m_OutputLayerId member. */
00115 itkSetMacro(OutputLayerId,unsigned int);
00116 itkGetConstMacro(OutputLayerId,unsigned int);
00117
00118 protected:
00119 /** Constructor and destructor are protected, so client code cannot construct or destroy instances directly. */
00120 WeightSetBase();
00121 ~WeightSetBase();
00122 /** Virtual print hook taking an output stream and an Indent. */
00124 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00125 /** Smart pointer to the random generator. */
00126 typename RandomVariateGeneratorType::Pointer m_RandomGenerator;
00127 /** Node counts and the output / input-error buffers. */
00128 unsigned int m_NumberOfInputNodes;
00129 unsigned int m_NumberOfOutputNodes;
00130 vnl_matrix<ValueType> m_OutputValues;
00131 vnl_matrix<ValueType> m_InputErrorValues;
00132 // Internal storage. The meanings below are inferred from member and accessor
00133 // names only; the .txx that reads and writes them is not shown -- confirm.
00134 //   m_DW*  (matrices) / m_DB*   (vectors): presumably weight / bias update terms.
00135 //   m_Del* (matrices) / m_Delb* (vectors): presumably delta terms for weights / biases.
00136 //   _new, _m_1, _m_2 suffixes: presumably the newest value and the one / two steps before it
00137 //   (compare GetPrevDWValues and GetPrev_m_2DWValues in the public section).
00138 // Declaration order is the construction order of these members; keep it stable.
00139 vnl_matrix<ValueType> m_DW;
00140 vnl_matrix<ValueType> m_DW_new;
00141 vnl_matrix<ValueType> m_DW_m_1;
00142 vnl_matrix<ValueType> m_DW_m_2;
00143 vnl_matrix<ValueType> m_DW_m; // only the m_DW family has a plain _m member
00144 // Vector counterparts of the m_DW matrices.
00145 vnl_vector<ValueType> m_DB;
00146 vnl_vector<ValueType> m_DB_new;
00147 vnl_vector<ValueType> m_DB_m_1;
00148 vnl_vector<ValueType> m_DB_m_2;
00149 // Matrix-valued m_Del family.
00150 vnl_matrix<ValueType> m_Del;
00151 vnl_matrix<ValueType> m_Del_new;
00152 vnl_matrix<ValueType> m_Del_m_1;
00153 vnl_matrix<ValueType> m_Del_m_2;
00154 // Vector counterparts of the m_Del matrices.
00155 vnl_vector<ValueType> m_Delb;
00156 vnl_vector<ValueType> m_Delb_new;
00157 vnl_vector<ValueType> m_Delb_m_1;
00158 vnl_vector<ValueType> m_Delb_m_2;
00159 // NOTE(review): presumably the values passed to ForwardPropagate, and the weights themselves -- confirm.
00160 vnl_matrix<ValueType> m_InputLayerOutput;
00161 vnl_matrix<ValueType> m_WeightMatrix;
00162 // Integer matrix; the only setter visible in this class is SetConnectivityMatrix.
00163 vnl_matrix<int> m_ConnectivityMatrix;
00164 // Scalar parameters and pass flags.
00165 ValueType m_Momentum;
00166 ValueType m_Bias;
00167 bool m_FirstPass;
00168 bool m_SecondPass;
00169 ValueType m_Range;
00170 // Identifiers for the attached layers and for this weight set.
00171 unsigned int m_InputLayerId;
00172 unsigned int m_OutputLayerId;
00173 unsigned int m_WeightSetId;
00174
00175 };
00176
00177 }
00178 }
00179
00180 #ifndef ITK_MANUAL_INSTANTIATION
00181 #include "itkWeightSetBase.txx"
00182 #endif
00183
00184
00185 #endif
00186