ITK  4.4.0
Insight Segmentation and Registration Toolkit
itkWeightSetBase.h
Go to the documentation of this file.
1 /*=========================================================================
2  *
3  * Copyright Insight Software Consortium
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  *=========================================================================*/
18 #ifndef __itkWeightSetBase_h
19 #define __itkWeightSetBase_h
20 
21 #include "itkLightProcessObject.h"
22 #include "vnl/vnl_matrix.h"
23 #include "vnl/vnl_diag_matrix.h"
24 #include "itkVector.h"
26 #include <cmath>
27 #include <cstdlib>
28 
29 namespace itk
30 {
31 namespace Statistics
32 {
/** \class WeightSetBase
 * \brief Base class template for one set of connection weights between an
 * input layer and an output layer (identified by InputLayerId and
 * OutputLayerId), with forward/backward propagation and weight updates.
 *
 * m_WeightMatrix holds the weights plus one column of biases. UpdateWeights()
 * takes a learning rate, and the m_DW / m_DB / m_Del / m_Delb member families
 * keep current, previous (_m_1) and second-to-last (_m_2) update terms.
 *
 * \tparam TMeasurementVector Must expose a nested ValueType; that type is used
 *         for every weight, value buffer and learning parameter below.
 * \tparam TTargetVector      Not referenced anywhere in the lines visible here.
 *
 * NOTE(review): this text is a Doxygen source-browser extract, not a
 * compilable header. The leading number on each code line is its line number
 * in the original itkWeightSetBase.h, and several original lines are absent
 * (33-38, 40, 44-48, 50, 53, 58, 76-77, 79, 81-94, 127, 130, 169-173).
 * Restore the file from the upstream ITK 4.4.0 sources before compiling.
 */
39 template<class TMeasurementVector, class TTargetVector>
// NOTE(review): original line 40, the class head, is missing here. Given the
// constructor name below and the itkLightProcessObject.h include, it is
// presumably "class WeightSetBase : public LightProcessObject" -- confirm upstream.
41 {
42 public:
43 
  // NOTE(review): original lines 44-48 and 50 are missing. In ITK classes this
  // spot usually holds the Self/Superclass/SmartPointer typedefs and the
  // itkTypeMacro/itkNewMacro lines -- confirm upstream.
49 
51 
  /** Scalar type of every weight and value, taken from the measurement vector. */
52  typedef typename TMeasurementVector::ValueType ValueType;
  // NOTE(review): original line 53 is missing. ValuePointer is used below but
  // not declared in this listing; presumably "typedef ValueType* ValuePointer;".
54  typedef const ValueType* ValueConstPointer;
55 
  /** Initialization entry point. It is defined in itkWeightSetBase.hxx (included
   *  at the end of this header), so what it sets up is not visible here. */
56  void Initialize();
57 
59 
  /** Forward/backward propagation through this weight set. ForwardPropagate
   *  receives the input layer's output values and BackwardPropagate receives
   *  an error array, both as raw pointers. The expected array lengths are not
   *  stated here (presumably NumberOfInputNodes / NumberOfOutputNodes --
   *  confirm in the .hxx). */
60  virtual void ForwardPropagate(ValuePointer inputlayeroutputvalues);
61  virtual void BackwardPropagate(ValuePointer inputerror);
62 
  /** Set which input/output node pairs are connected. The integer matrix is
   *  passed by value (copied); presumably stored in m_ConnectivityMatrix. */
63  void SetConnectivityMatrix(vnl_matrix < int>);
64 
  /** Number of nodes in the input layer this weight set reads from. */
65  void SetNumberOfInputNodes(unsigned int n);
66  unsigned int GetNumberOfInputNodes() const;
67 
  /** Number of nodes in the output layer this weight set feeds. */
68  void SetNumberOfOutputNodes(unsigned int n);
69  unsigned int GetNumberOfOutputNodes() const;
70 
  /** NOTE(review): no m_Range member is visible in this listing (see the gap
   *  at original lines 169-173); how the range is used is defined in the .hxx. */
71  void SetRange(ValueType Range);
72 
  /** Return raw pointers to the output / input value arrays. These are
   *  presumably non-owning views of m_OutputValues / m_InputLayerOutput --
   *  confirm in the .hxx before holding them beyond the object's lifetime. */
73  virtual ValuePointer GetOutputValues();
74  virtual ValuePointer GetInputValues();
75 
  // NOTE(review): original lines 76-77, 79 and 81-94 are missing from this
  // listing; further public declarations may have been there.
78 
80 
95 
  /** SetWeightValues loads weights from a raw array (whether it copies or
   *  adopts the array is defined in the .hxx). UpdateWeights applies a weight
   *  update scaled by LearningRate (see the dw = lr * del * y note on the
   *  members below). */
96  void SetWeightValues(ValuePointer weights);
97  virtual void UpdateWeights(ValueType LearningRate);
98 
  /** Momentum and Bias parameters. The ITK accessor macros generate
   *  SetMomentum()/GetMomentum() and SetBias()/GetBias(), backed by m_Momentum
   *  and m_Bias (whose declarations are missing from this listing). */
99  itkSetMacro( Momentum, ValueType );
100  itkGetConstReferenceMacro( Momentum, ValueType );
101 
102  itkSetMacro( Bias, ValueType );
103  itkGetConstReferenceMacro( Bias, ValueType );
104 
  /** FirstPass / SecondPass flags with Set/Get accessors. Their meaning is
   *  defined in the .hxx -- presumably they mark whether the previous (_m_1)
   *  and second-to-last (_m_2) update terms are valid yet; TODO confirm. */
105  itkSetMacro( FirstPass, bool );
106  itkGetConstMacro( FirstPass, bool );
107 
108  itkSetMacro( SecondPass, bool );
109  itkGetConstMacro( SecondPass, bool );
110 
  /** Defined in the .hxx; presumably fills m_WeightMatrix with initial values. */
111  void InitializeWeights();
112 
  /** Identifier of this weight set and of the input/output layers it connects. */
113  itkSetMacro(WeightSetId,unsigned int);
114  itkGetConstMacro(WeightSetId,unsigned int);
115 
116  itkSetMacro(InputLayerId,unsigned int);
117  itkGetConstMacro(InputLayerId,unsigned int);
118 
119  itkSetMacro(OutputLayerId,unsigned int);
120  itkGetConstMacro(OutputLayerId,unsigned int);
121 
122 protected:
123 
  /** Protected: only this class and derived classes can construct or destroy
   *  instances directly. ITK objects are normally obtained through a New()
   *  factory method (not visible in this listing). */
124  WeightSetBase();
125  ~WeightSetBase();
126 
  /** Print the object's state to os at the given indentation (ITK convention). */
128  virtual void PrintSelf( std::ostream& os, Indent indent ) const;
129 
131 
  // Layer sizes, output values, and input-error values (presumably what
  // BackwardPropagate produces -- confirm in the .hxx).
132  unsigned int m_NumberOfInputNodes;
133  unsigned int m_NumberOfOutputNodes;
134  vnl_matrix<ValueType> m_OutputValues;
135  vnl_matrix<ValueType> m_InputErrorValues;
136 
137  // Weight-update rule: dw = lr * del * y (lr = learning rate).
138  // m_DW     = current update term
139  // m_DW_m_1 = previous update term
140  // m_DW_m_2 = update term from two steps back
141  // The same suffix scheme applies to the m_Del*, m_DB* and m_Delb* families.
  // The roles of the *_new members and of m_DW_m are not stated here; see the .hxx.
142 
143  vnl_matrix<ValueType> m_DW; // delta values for weight update
144  vnl_matrix<ValueType> m_DW_new; // delta values for weight update
145  vnl_matrix<ValueType> m_DW_m_1; // delta values for weight update
146  vnl_matrix<ValueType> m_DW_m_2; // delta values for weight update
147  vnl_matrix<ValueType> m_DW_m; // delta values for weight update
148 
149  vnl_vector<ValueType> m_DB; // delta values for bias update
150  vnl_vector<ValueType> m_DB_new; // delta values for bias update
151  vnl_vector<ValueType> m_DB_m_1; // delta values for bias update
152  vnl_vector<ValueType> m_DB_m_2; // delta values for bias update
153 
154  vnl_matrix<ValueType> m_Del; // dw=lr * del * y
155  vnl_matrix<ValueType> m_Del_new; // dw=lr * del * y
156  vnl_matrix<ValueType> m_Del_m_1; // dw=lr * del * y
157  vnl_matrix<ValueType> m_Del_m_2; // dw=lr * del * y
158 
159  vnl_vector<ValueType> m_Delb; // delta values for bias update
160  vnl_vector<ValueType> m_Delb_new; // delta values for bias update
161  vnl_vector<ValueType> m_Delb_m_1; // delta values for bias update
162  vnl_vector<ValueType> m_Delb_m_2; // delta values for bias update
163 
164  vnl_matrix<ValueType> m_InputLayerOutput; // presumably the values passed to ForwardPropagate -- confirm
165  vnl_matrix<ValueType> m_WeightMatrix; // composed of weights and a column
166  // of biases
167  vnl_matrix<int> m_ConnectivityMatrix;
168 
  // NOTE(review): original lines 169-173 are missing. They presumably declare
  // m_Momentum, m_Bias, m_FirstPass, m_SecondPass and m_Range, which the
  // accessors above rely on -- restore from upstream.
174 
175  unsigned int m_InputLayerId;
176  unsigned int m_OutputLayerId;
177  unsigned int m_WeightSetId;
178 
179 }; //class
180 
181 } // end namespace Statistics
182 } // end namespace itk
183 
184 #ifndef ITK_MANUAL_INSTANTIATION
185 #include "itkWeightSetBase.hxx"
186 #endif
187 
188 
189 #endif
190