ITK
4.1.0
Insight Segmentation and Registration Toolkit
|
00001 /*========================================================================= 00002 * 00003 * Copyright Insight Software Consortium 00004 * 00005 * Licensed under the Apache License, Version 2.0 (the "License"); 00006 * you may not use this file except in compliance with the License. 00007 * You may obtain a copy of the License at 00008 * 00009 * http://www.apache.org/licenses/LICENSE-2.0.txt 00010 * 00011 * Unless required by applicable law or agreed to in writing, software 00012 * distributed under the License is distributed on an "AS IS" BASIS, 00013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00014 * See the License for the specific language governing permissions and 00015 * limitations under the License. 00016 * 00017 *=========================================================================*/ 00018 #ifndef __itkWeightSetBase_h 00019 #define __itkWeightSetBase_h 00020 00021 #include "itkLightProcessObject.h" 00022 #include "vnl/vnl_matrix.h" 00023 #include "vnl/vnl_diag_matrix.h" 00024 #include "itkVector.h" 00025 #include "itkMersenneTwisterRandomVariateGenerator.h" 00026 #include <math.h> 00027 #include <stdlib.h> 00028 00029 namespace itk 00030 { 00031 namespace Statistics 00032 { 00039 template<class TMeasurementVector, class TTargetVector> 00040 class WeightSetBase : public LightProcessObject 00041 { 00042 public: 00043 00044 typedef WeightSetBase Self; 00045 typedef LightProcessObject Superclass; 00046 typedef SmartPointer<Self> Pointer; 00047 typedef SmartPointer<const Self> ConstPointer; 00048 itkTypeMacro(WeightSetBase, LightProcessObject); 00049 00050 typedef MersenneTwisterRandomVariateGenerator RandomVariateGeneratorType; 00051 00052 typedef typename TMeasurementVector::ValueType ValueType; 00053 typedef ValueType* ValuePointer; 00054 typedef const ValueType* ValueConstPointer; 00055 00056 void Initialize(); 00057 00058 ValueType RandomWeightValue(ValueType low, ValueType high); 00059 00060 virtual void ForwardPropagate(ValuePointer inputlayeroutputvalues); 00061 virtual void BackwardPropagate(ValuePointer inputerror); 00062 00063 void SetConnectivityMatrix(vnl_matrix < int>); 00064 00065 void SetNumberOfInputNodes(unsigned int n); 00066 unsigned int GetNumberOfInputNodes() const; 00067 00068 void SetNumberOfOutputNodes(unsigned int n); 00069 unsigned int GetNumberOfOutputNodes() const; 00070 00071 void SetRange(ValueType Range); 00072 00073 virtual ValuePointer GetOutputValues(); 00074 virtual ValuePointer GetInputValues(); 00075 00076 ValuePointer GetTotalDeltaValues(); 00077 ValuePointer GetTotalDeltaBValues(); 00078 00079 ValuePointer GetDeltaValues(); 00080 00081 void SetDeltaValues(ValuePointer); 00082 void SetDWValues(ValuePointer); 00083 void SetDBValues(ValuePointer); 00084 ValuePointer GetDeltaBValues(); 00085 void SetDeltaBValues(ValuePointer); 00086 ValuePointer GetDWValues(); 00087 ValuePointer GetPrevDWValues(); 00088 ValuePointer GetPrevDBValues(); 00089 ValuePointer GetPrev_m_2DWValues(); 00090 ValuePointer GetPrevDeltaValues(); 00091 ValuePointer GetPrev_m_2DeltaValues(); 00092 ValuePointer GetPrevDeltaBValues(); 00093 ValuePointer GetWeightValues(); 00094 ValueConstPointer GetWeightValues() const; 00095 00096 void SetWeightValues(ValuePointer weights); 00097 virtual void UpdateWeights(ValueType LearningRate); 00098 00099 itkSetMacro( Momentum, ValueType ); 00100 itkGetConstReferenceMacro( Momentum, ValueType ); 00101 00102 itkSetMacro( Bias, ValueType ); 00103 itkGetConstReferenceMacro( Bias, ValueType ); 00104 00105 itkSetMacro( FirstPass, bool ); 00106 itkGetConstMacro( FirstPass, bool ); 00107 00108 itkSetMacro( SecondPass, bool ); 00109 itkGetConstMacro( SecondPass, bool ); 00110 00111 void InitializeWeights(); 00112 00113 itkSetMacro(WeightSetId,unsigned int); 00114 itkGetConstMacro(WeightSetId,unsigned int); 00115 00116 itkSetMacro(InputLayerId,unsigned int); 00117 itkGetConstMacro(InputLayerId,unsigned int); 00118 00119 itkSetMacro(OutputLayerId,unsigned int); 00120 itkGetConstMacro(OutputLayerId,unsigned int); 00121 00122 protected: 00123 00124 WeightSetBase(); 00125 ~WeightSetBase(); 00126 00128 virtual void PrintSelf( std::ostream& os, Indent indent ) const; 00129 00130 typename RandomVariateGeneratorType::Pointer m_RandomGenerator; 00131 00132 unsigned int m_NumberOfInputNodes; 00133 unsigned int m_NumberOfOutputNodes; 00134 vnl_matrix<ValueType> m_OutputValues; 00135 vnl_matrix<ValueType> m_InputErrorValues; 00136 00137 // weight updates dw=lr * del *y 00138 // DW= current 00139 // DW_m_1 = previous 00140 // DW_m_2= second to last 00141 // same applies for delta and bias values 00142 00143 vnl_matrix<ValueType> m_DW; // delta valies for weight update 00144 vnl_matrix<ValueType> m_DW_new; // delta valies for weight update 00145 vnl_matrix<ValueType> m_DW_m_1; // delta valies for weight update 00146 vnl_matrix<ValueType> m_DW_m_2; // delta valies for weight update 00147 vnl_matrix<ValueType> m_DW_m; // delta valies for weight update 00148 00149 vnl_vector<ValueType> m_DB; // delta values for bias update 00150 vnl_vector<ValueType> m_DB_new; // delta values for bias update 00151 vnl_vector<ValueType> m_DB_m_1; // delta values for bias update 00152 vnl_vector<ValueType> m_DB_m_2; // delta values for bias update 00153 00154 vnl_matrix<ValueType> m_Del; // dw=lr * del * y 00155 vnl_matrix<ValueType> m_Del_new; // dw=lr * del * y 00156 vnl_matrix<ValueType> m_Del_m_1; // dw=lr * del * y 00157 vnl_matrix<ValueType> m_Del_m_2; // dw=lr * del * y 00158 00159 vnl_vector<ValueType> m_Delb; // delta values for bias update 00160 vnl_vector<ValueType> m_Delb_new; // delta values for bias update 00161 vnl_vector<ValueType> m_Delb_m_1; // delta values for bias update 00162 vnl_vector<ValueType> m_Delb_m_2; // delta values for bias update 00163 00164 vnl_matrix<ValueType> m_InputLayerOutput; 00165 vnl_matrix<ValueType> m_WeightMatrix; // composed of weights and a column 00166 // of biases 00167 vnl_matrix<int> m_ConnectivityMatrix; 00168 00169 ValueType m_Momentum; 00170 ValueType m_Bias; 00171 bool m_FirstPass; 00172 bool m_SecondPass; 00173 ValueType m_Range; 00174 00175 unsigned int m_InputLayerId; 00176 unsigned int m_OutputLayerId; 00177 unsigned int m_WeightSetId; 00178 00179 }; //class 00180 00181 } // end namespace Statistics 00182 } // end namespace itk 00183 00184 #ifndef ITK_MANUAL_INSTANTIATION 00185 #include "itkWeightSetBase.hxx" 00186 #endif 00187 00188 00189 #endif 00190