ITK  4.13.0
Insight Segmentation and Registration Toolkit
itkWeightSetBase.h
Go to the documentation of this file.
1 /*=========================================================================
2  *
3  * Copyright Insight Software Consortium
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  *=========================================================================*/
18 #ifndef itkWeightSetBase_h
19 #define itkWeightSetBase_h
20 
21 #include "itkLightProcessObject.h"
22 #include "vnl/vnl_matrix.h"
23 #include "vnl/vnl_diag_matrix.h"
24 #include "itkVector.h"
26 #include <cmath>
27 #include <cstdlib>
28 
29 namespace itk
30 {
31 namespace Statistics
32 {
/** \class WeightSetBase
 * \brief Declares the storage and accessors for a weight matrix together
 * with the current and previous weight, bias and delta update buffers.
 *
 * The class is templated over a measurement vector type, which supplies
 * ValueType, and a target vector type, which is not referenced by any of
 * the declarations visible in this listing.
 *
 * NOTE(review): this listing omits some original lines (the embedded line
 * numbers skip 44-48, 50, 53 and 126-127). ValuePointer and
 * RandomVariateGeneratorType are used below but their typedefs are not
 * visible here -- confirm against the full header before editing.
 *
 * Member function definitions are not in this header; it includes
 * itkWeightSetBase.hxx unless ITK_MANUAL_INSTANTIATION is defined.
 */
39 template<typename TMeasurementVector, typename TTargetVector>
40 class ITK_TEMPLATE_EXPORT WeightSetBase : public LightProcessObject
41 {
42 public:
43 
49 
51 
    // Scalar element type, taken from the measurement vector type.
52  typedef typename TMeasurementVector::ValueType ValueType;
    // Pointer to read-only ValueType data.
54  typedef const ValueType* ValueConstPointer;
55 
56  void Initialize();
57 
    // Presumably returns a random value between low and high -- TODO confirm
    // the distribution and bound inclusiveness in the .hxx.
58  ValueType RandomWeightValue(ValueType low, ValueType high);
59 
    // Virtual propagation hooks; subclasses may override them.
60  virtual void ForwardPropagate(ValuePointer inputlayeroutputvalues);
61  virtual void BackwardPropagate(ValuePointer inputerror);
62 
    // Note: the matrix argument is taken by value.
63  void SetConnectivityMatrix(vnl_matrix < int>);
64 
65  void SetNumberOfInputNodes(unsigned int n);
66  unsigned int GetNumberOfInputNodes() const;
67 
68  void SetNumberOfOutputNodes(unsigned int n);
69  unsigned int GetNumberOfOutputNodes() const;
70 
71  void SetRange(ValueType Range);
72 
73  virtual ValuePointer GetOutputValues();
74  virtual ValuePointer GetInputValues();
75 
    // Raw-pointer accessors. NOTE(review): which member buffer each accessor
    // exposes, and who owns the pointed-to memory, is defined in the .hxx and
    // cannot be determined from this header -- verify before relying on it.
76  ValuePointer GetTotalDeltaValues();
77  ValuePointer GetTotalDeltaBValues();
78 
79  ValuePointer GetDeltaValues();
80 
81  void SetDeltaValues(ValuePointer);
82  void SetDWValues(ValuePointer);
83  void SetDBValues(ValuePointer);
84  ValuePointer GetDeltaBValues();
85  void SetDeltaBValues(ValuePointer);
86  ValuePointer GetDWValues();
87  ValuePointer GetPrevDWValues();
88  ValuePointer GetPrevDBValues();
89  ValuePointer GetPrev_m_2DWValues();
90  ValuePointer GetPrevDeltaValues();
91  ValuePointer GetPrev_m_2DeltaValues();
92  ValuePointer GetPrevDeltaBValues();
    // Mutable and read-only overloads for access to the weight values.
93  ValuePointer GetWeightValues();
94  ValueConstPointer GetWeightValues() const;
95 
96  void SetWeightValues(ValuePointer weights);
97  virtual void UpdateWeights(ValueType LearningRate);
98 
    // The itkSetMacro / itkGetConst*Macro invocations below follow the ITK
    // convention of generating Set<Name>() / Get<Name>() accessors for the
    // matching m_<Name> member declared in the protected section.
99  itkSetMacro( Momentum, ValueType );
100  itkGetConstReferenceMacro( Momentum, ValueType );
101 
102  itkSetMacro( Bias, ValueType );
103  itkGetConstReferenceMacro( Bias, ValueType );
104 
105  itkSetMacro( FirstPass, bool );
106  itkGetConstMacro( FirstPass, bool );
107 
108  itkSetMacro( SecondPass, bool );
109  itkGetConstMacro( SecondPass, bool );
110 
111  void InitializeWeights();
112 
113  itkSetMacro(WeightSetId,unsigned int);
114  itkGetConstMacro(WeightSetId,unsigned int);
115 
116  itkSetMacro(InputLayerId,unsigned int);
117  itkGetConstMacro(InputLayerId,unsigned int);
118 
119  itkSetMacro(OutputLayerId,unsigned int);
120  itkGetConstMacro(OutputLayerId,unsigned int);
121 
122 protected:
123 
    // Constructor and destructor are protected, so instances cannot be
    // created or destroyed directly by code outside the class hierarchy.
124  WeightSetBase();
125  ~WeightSetBase() ITK_OVERRIDE;
126 
    // Overrides the base-class PrintSelf.
128  virtual void PrintSelf( std::ostream& os, Indent indent ) const ITK_OVERRIDE;
129 
    // Smart pointer to the random variate generator; the
    // RandomVariateGeneratorType typedef is not visible in this listing.
130  typename RandomVariateGeneratorType::Pointer m_RandomGenerator;
131 
132  unsigned int m_NumberOfInputNodes;
133  unsigned int m_NumberOfOutputNodes;
134  vnl_matrix<ValueType> m_OutputValues;
135  vnl_matrix<ValueType> m_InputErrorValues;
136 
137  // weight updates dw = lr * del * y
138  // DW = current
139  // DW_m_1 = previous
140  // DW_m_2 = second to last
141  // same applies for delta and bias values
142 
143  vnl_matrix<ValueType> m_DW; // delta values for weight update
144  vnl_matrix<ValueType> m_DW_new; // delta values for weight update
145  vnl_matrix<ValueType> m_DW_m_1; // delta values for weight update
146  vnl_matrix<ValueType> m_DW_m_2; // delta values for weight update
147  vnl_matrix<ValueType> m_DW_m; // delta values for weight update
148 
149  vnl_vector<ValueType> m_DB; // delta values for bias update
150  vnl_vector<ValueType> m_DB_new; // delta values for bias update
151  vnl_vector<ValueType> m_DB_m_1; // delta values for bias update
152  vnl_vector<ValueType> m_DB_m_2; // delta values for bias update
153 
154  vnl_matrix<ValueType> m_Del; // dw=lr * del * y
155  vnl_matrix<ValueType> m_Del_new; // dw=lr * del * y
156  vnl_matrix<ValueType> m_Del_m_1; // dw=lr * del * y
157  vnl_matrix<ValueType> m_Del_m_2; // dw=lr * del * y
158 
159  vnl_vector<ValueType> m_Delb; // delta values for bias update
160  vnl_vector<ValueType> m_Delb_new; // delta values for bias update
161  vnl_vector<ValueType> m_Delb_m_1; // delta values for bias update
162  vnl_vector<ValueType> m_Delb_m_2; // delta values for bias update
163 
164  vnl_matrix<ValueType> m_InputLayerOutput;
165  vnl_matrix<ValueType> m_WeightMatrix; // composed of weights and a column
166  // of biases
    // Integer matrix stored via SetConnectivityMatrix().
    // NOTE(review): the meaning of its entries is not shown here -- confirm.
167  vnl_matrix<int> m_ConnectivityMatrix;
168 
    // Backing members for the Set/Get macros in the public section.
169  ValueType m_Momentum;
170  ValueType m_Bias;
171  bool m_FirstPass;
172  bool m_SecondPass;
    // Set through SetRange(); no Get accessor is declared for it.
173  ValueType m_Range;
174 
175  unsigned int m_InputLayerId;
176  unsigned int m_OutputLayerId;
177  unsigned int m_WeightSetId;
178 
179 }; //class
180 
181 } // end namespace Statistics
182 } // end namespace itk
183 
184 #ifndef ITK_MANUAL_INSTANTIATION
185 #include "itkWeightSetBase.hxx"
186 #endif
187 
188 
189 #endif
This is the itkWeightSetBase class.
MersenneTwisterRandomVariateGenerator RandomVariateGeneratorType
SmartPointer< const Self > ConstPointer
SmartPointer< Self > Pointer
TMeasurementVector::ValueType ValueType
LightProcessObject is the base class for all process objects (source, filters, mappers) in the Insight data processing pipeline.
Control indentation during Print() invocation.
Definition: itkIndent.h:49
const ValueType * ValueConstPointer