ITK  4.9.0
Insight Segmentation and Registration Toolkit
itkWeightSetBase.h
Go to the documentation of this file.
1 /*=========================================================================
2  *
3  * Copyright Insight Software Consortium
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  *=========================================================================*/
18 #ifndef itkWeightSetBase_h
19 #define itkWeightSetBase_h
20 
21 #include "itkLightProcessObject.h"
22 #include "vnl/vnl_matrix.h"
23 #include "vnl/vnl_diag_matrix.h"
24 #include "itkVector.h"
26 #include <cmath>
27 #include <cstdlib>
28 
29 namespace itk
30 {
31 namespace Statistics
32 {
// Class template declaration as rendered in the Doxygen source listing of
// itkWeightSetBase.h; the leading number on each line is the listing's own line number.
// NOTE(review): listing line 40 (the class-head; presumably `class WeightSetBase ...`,
// judging by the constructor/destructor names below) is missing from this extract,
// as are the other lines whose listing numbers are skipped. Confirm against the real header.
39 template<typename TMeasurementVector, typename TTargetVector>
41 {
42 public:
43 
49 
51 
 // Scalar element type, taken from the measurement vector type.
52  typedef typename TMeasurementVector::ValueType ValueType;
54  typedef const ValueType* ValueConstPointer;
 // NOTE(review): ValuePointer is used by the declarations below, but its typedef
 // (listing line 53) is not visible in this extract.
55 
56  void Initialize();
57 
59 
 // Propagation entry points; both are virtual so subclasses may override them.
60  virtual void ForwardPropagate(ValuePointer inputlayeroutputvalues);
61  virtual void BackwardPropagate(ValuePointer inputerror);
62 
 // Takes the connectivity matrix by value (see m_ConnectivityMatrix below).
63  void SetConnectivityMatrix(vnl_matrix < int>);
64 
65  void SetNumberOfInputNodes(unsigned int n);
66  unsigned int GetNumberOfInputNodes() const;
67 
68  void SetNumberOfOutputNodes(unsigned int n);
69  unsigned int GetNumberOfOutputNodes() const;
70 
71  void SetRange(ValueType Range);
72 
73  virtual ValuePointer GetOutputValues();
74  virtual ValuePointer GetInputValues();
75 
78 
80 
95 
96  void SetWeightValues(ValuePointer weights);
97  virtual void UpdateWeights(ValueType LearningRate);
98 
 // Accessors declared through the itkSetMacro / itkGetConst*Macro helpers
 // (the macros themselves are not defined in this file).
99  itkSetMacro( Momentum, ValueType );
100  itkGetConstReferenceMacro( Momentum, ValueType );
101 
102  itkSetMacro( Bias, ValueType );
103  itkGetConstReferenceMacro( Bias, ValueType );
104 
105  itkSetMacro( FirstPass, bool );
106  itkGetConstMacro( FirstPass, bool );
107 
108  itkSetMacro( SecondPass, bool );
109  itkGetConstMacro( SecondPass, bool );
110 
111  void InitializeWeights();
112 
113  itkSetMacro(WeightSetId,unsigned int);
114  itkGetConstMacro(WeightSetId,unsigned int);
115 
116  itkSetMacro(InputLayerId,unsigned int);
117  itkGetConstMacro(InputLayerId,unsigned int);
118 
119  itkSetMacro(OutputLayerId,unsigned int);
120  itkGetConstMacro(OutputLayerId,unsigned int);
121 
122 protected:
123 
 // Constructor and destructor are protected.
124  WeightSetBase();
125  ~WeightSetBase();
126 
 // Overrides the base-class PrintSelf (marked ITK_OVERRIDE).
128  virtual void PrintSelf( std::ostream& os, Indent indent ) const ITK_OVERRIDE;
129 
131 
132  unsigned int m_NumberOfInputNodes;
133  unsigned int m_NumberOfOutputNodes;
134  vnl_matrix<ValueType> m_OutputValues;
135  vnl_matrix<ValueType> m_InputErrorValues;
136 
137  // Weight updates: dw = lr * del * y
138  // DW     = current
139  // DW_m_1 = previous
140  // DW_m_2 = second to last
141  // The same naming applies to the delta (Del) and bias (DB, Delb) values.
142 
143  vnl_matrix<ValueType> m_DW; // delta values for weight update (current)
144  vnl_matrix<ValueType> m_DW_new; // delta values for weight update
145  vnl_matrix<ValueType> m_DW_m_1; // delta values for weight update (previous)
146  vnl_matrix<ValueType> m_DW_m_2; // delta values for weight update (second to last)
147  vnl_matrix<ValueType> m_DW_m; // delta values for weight update
148 
149  vnl_vector<ValueType> m_DB; // delta values for bias update (current)
150  vnl_vector<ValueType> m_DB_new; // delta values for bias update
151  vnl_vector<ValueType> m_DB_m_1; // delta values for bias update (previous)
152  vnl_vector<ValueType> m_DB_m_2; // delta values for bias update (second to last)
153 
154  vnl_matrix<ValueType> m_Del; // del term of dw = lr * del * y (current)
155  vnl_matrix<ValueType> m_Del_new; // del term of dw = lr * del * y
156  vnl_matrix<ValueType> m_Del_m_1; // del term of dw = lr * del * y (previous)
157  vnl_matrix<ValueType> m_Del_m_2; // del term of dw = lr * del * y (second to last)
158 
159  vnl_vector<ValueType> m_Delb; // delta values for bias update (current)
160  vnl_vector<ValueType> m_Delb_new; // delta values for bias update
161  vnl_vector<ValueType> m_Delb_m_1; // delta values for bias update (previous)
162  vnl_vector<ValueType> m_Delb_m_2; // delta values for bias update (second to last)
163 
164  vnl_matrix<ValueType> m_InputLayerOutput;
165  vnl_matrix<ValueType> m_WeightMatrix; // composed of weights and a column
166  // of biases
167  vnl_matrix<int> m_ConnectivityMatrix;
168 
 // NOTE(review): listing lines 169-173 are missing from this extract; the storage behind
 // the Momentum/Bias/FirstPass/SecondPass accessors is presumably declared there.
174 
 // Identifiers exposed through the Set/Get macro accessors declared above.
175  unsigned int m_InputLayerId;
176  unsigned int m_OutputLayerId;
177  unsigned int m_WeightSetId;
178 
179 }; //class
180 
181 } // end namespace Statistics
182 } // end namespace itk
183 
184 #ifndef ITK_MANUAL_INSTANTIATION
185 #include "itkWeightSetBase.hxx"
186 #endif
187 
188 
189 #endif
virtual void PrintSelf(std::ostream &os, Indent indent) const override
vnl_vector< ValueType > m_DB_new
vnl_matrix< ValueType > m_DW_new
This is the itkWeightSetBase class.
vnl_vector< ValueType > m_DB
vnl_matrix< ValueType > m_OutputValues
MersenneTwisterRandomVariateGenerator RandomVariateGeneratorType
ValuePointer GetTotalDeltaBValues()
ValuePointer GetPrev_m_2DWValues()
vnl_matrix< ValueType > m_DW_m
ValuePointer GetPrev_m_2DeltaValues()
virtual void BackwardPropagate(ValuePointer inputerror)
vnl_matrix< ValueType > m_InputErrorValues
vnl_vector< ValueType > m_DB_m_1
void SetConnectivityMatrix(vnl_matrix< int >)
void SetDeltaBValues(ValuePointer)
virtual void UpdateWeights(ValueType LearningRate)
SmartPointer< const Self > ConstPointer
vnl_matrix< ValueType > m_Del_new
ValueType RandomWeightValue(ValueType low, ValueType high)
vnl_matrix< ValueType > m_InputLayerOutput
vnl_matrix< ValueType > m_Del_m_2
void SetDWValues(ValuePointer)
virtual ValuePointer GetOutputValues()
ValuePointer GetPrevDeltaBValues()
vnl_matrix< ValueType > m_Del_m_1
void SetNumberOfOutputNodes(unsigned int n)
vnl_vector< ValueType > m_Delb_m_2
vnl_matrix< ValueType > m_DW_m_1
void SetNumberOfInputNodes(unsigned int n)
vnl_vector< ValueType > m_Delb_m_1
vnl_vector< ValueType > m_Delb_new
SmartPointer< Self > Pointer
unsigned int GetNumberOfOutputNodes() const
void SetRange(ValueType Range)
vnl_matrix< ValueType > m_Del
RandomVariateGeneratorType::Pointer m_RandomGenerator
vnl_matrix< int > m_ConnectivityMatrix
vnl_vector< ValueType > m_DB_m_2
TMeasurementVector::ValueType ValueType
unsigned int GetNumberOfInputNodes() const
LightProcessObject is the base class for all process objects (source, filters, mappers) in the Insight data processing pipeline.
ValuePointer GetTotalDeltaValues()
Control indentation during Print() invocation.
Definition: itkIndent.h:49
vnl_matrix< ValueType > m_DW
virtual void ForwardPropagate(ValuePointer inputlayeroutputvalues)
void SetDeltaValues(ValuePointer)
void SetDBValues(ValuePointer)
vnl_vector< ValueType > m_Delb
vnl_matrix< ValueType > m_WeightMatrix
const ValueType * ValueConstPointer
void SetWeightValues(ValuePointer weights)
virtual ValuePointer GetInputValues()
vnl_matrix< ValueType > m_DW_m_2