ITK 4.6.0
Insight Segmentation and Registration Toolkit
itkMultilayerNeuralNetworkBase.h
Go to the documentation of this file.
1 /*=========================================================================
2  *
3  * Copyright Insight Software Consortium
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  *=========================================================================*/
18 #ifndef __itkMultilayerNeuralNetworkBase_h
19 #define __itkMultilayerNeuralNetworkBase_h
20 
21 #include "itkNeuralNetworkObject.h"
22 
23 namespace itk
24 {
25 namespace Statistics
26 {
33 template<typename TMeasurementVector, typename TTargetVector,typename TLearningLayer=LayerBase<TMeasurementVector, TTargetVector> >
34 class MultilayerNeuralNetworkBase : public NeuralNetworkObject<TMeasurementVector, TTargetVector>
35 {
36 public:
37 
43 
45 
47  itkNewMacro( Self );
48 
49  typedef typename Superclass::ValueType ValueType;
53 
55 
56  typedef TLearningLayer LearningLayerType;
59 
60  typedef std::vector<typename LayerInterfaceType::WeightSetInterfaceType::Pointer>
62  typedef std::vector<typename LayerInterfaceType::Pointer>
64 
67 
68 //#define __USE_OLD_INTERFACE Comment out to ensure that new interface works
69 #ifdef __USE_OLD_INTERFACE
70  itkSetMacro(NumOfLayers, int);
71  itkGetConstReferenceMacro(NumOfLayers, int);
72 
73  itkSetMacro(NumOfWeightSets, int);
74  itkGetConstReferenceMacro(NumOfWeightSets, int);
75 #else
76  int GetNumOfLayers(void) const
77  {
78  return m_Layers.size();
79  }
80  int GetNumOfWeightSets(void) const
81  {
82  return m_Weights.size();
83  }
84 
85 #endif
86 
88  LayerInterfaceType * GetLayer(int layer_id);
89  const LayerInterfaceType * GetLayer(int layer_id) const;
90 
91  void AddWeightSet(typename LayerInterfaceType::WeightSetInterfaceType*);
93  {
94  return m_Weights[id].GetPointer();
95  }
96 #ifdef __USE_OLD_INTERFACE
97  const typename LayerInterfaceType::WeightSetInterfaceType* GetWeightSet(unsigned int id) const;
98 #endif
99 
101 
102  virtual NetworkOutputType GenerateOutput(TMeasurementVector samplevector);
103 
104  virtual void BackwardPropagate(NetworkOutputType errors);
105  virtual void UpdateWeights(ValueType);
106 
108 
109  void SetLearningRate(ValueType learningrate);
110 
111  void InitializeWeights();
112 
113 protected:
116 
121  //#define __USE_OLD_INTERFACE Comment out to ensure that new interface works
122 #ifdef __USE_OLD_INTERFACE
123  //These are completely redundant variables that can be more reliably queried from
124  // m_Layers->size() and m_Weights->size();
125  int m_NumOfLayers;
126  int m_NumOfWeightSets;
127 #endif
128 
129  virtual void PrintSelf( std::ostream& os, Indent indent ) const;
130 };
131 
132 } // end namespace Statistics
133 } // end namespace itk
134 
135 #ifndef ITK_MANUAL_INSTANTIATION
136 #include "itkMultilayerNeuralNetworkBase.hxx"
137 #endif
138 
139 #endif
LayerBase< TMeasurementVector, TTargetVector > LayerInterfaceType
virtual NetworkOutputType GenerateOutput(TMeasurementVector samplevector)
void SetLearningRate(ValueType learningrate)
This is the itkWeightSetBase class.
std::vector< typename LayerInterfaceType::WeightSetInterfaceType::Pointer > WeightVectorType
This is the itkTransferFunctionBase class.
InputFunctionBase< ValueType *, ValueType > InputFunctionInterfaceType
void AddWeightSet(typename LayerInterfaceType::WeightSetInterfaceType *)
This is the itkInputFunctionBase class.
std::vector< typename LayerInterfaceType::Pointer > LayerVectorType
LayerInterfaceType * GetLayer(int layer_id)
void SetLearningFunction(LearningFunctionInterfaceType *f)
void SetLearningRule(LearningFunctionInterfaceType *)
virtual void PrintSelf(std::ostream &os, Indent indent) const
LearningFunctionInterfaceType::Pointer m_LearningFunction
MeasurementVectorType::ValueType ValueType
virtual void BackwardPropagate(NetworkOutputType errors)
This is the itkMultilayerNeuralNetworkBase class.
TransferFunctionBase< ValueType > TransferFunctionInterfaceType
This is the itkNeuralNetworkObject class.
Control indentation during Print() invocation.
Definition: itkIndent.h:49
The LearningFunctionBase is the base class for all the learning strategies.
NeuralNetworkObject< TMeasurementVector, TTargetVector > Superclass
LayerInterfaceType::WeightSetInterfaceType * GetWeightSet(unsigned int id)
Base class for all data objects in ITK.
LearningFunctionBase< typename TLearningLayer::LayerInterfaceType, TTargetVector > LearningFunctionInterfaceType