ITK 4.1.0
Insight Segmentation and Registration Toolkit
00001 /*========================================================================= 00002 * 00003 * Copyright Insight Software Consortium 00004 * 00005 * Licensed under the Apache License, Version 2.0 (the "License"); 00006 * you may not use this file except in compliance with the License. 00007 * You may obtain a copy of the License at 00008 * 00009 * http://www.apache.org/licenses/LICENSE-2.0.txt 00010 * 00011 * Unless required by applicable law or agreed to in writing, software 00012 * distributed under the License is distributed on an "AS IS" BASIS, 00013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00014 * See the License for the specific language governing permissions and 00015 * limitations under the License. 00016 * 00017 *=========================================================================*/ 00018 #ifndef __itkMultilayerNeuralNetworkBase_h 00019 #define __itkMultilayerNeuralNetworkBase_h 00020 00021 #include "itkNeuralNetworkObject.h" 00022 00023 namespace itk 00024 { 00025 namespace Statistics 00026 { 00033 template<class TMeasurementVector, class TTargetVector,class TLearningLayer=LayerBase<TMeasurementVector, TTargetVector> > 00034 class MultilayerNeuralNetworkBase : public NeuralNetworkObject<TMeasurementVector, TTargetVector> 00035 { 00036 public: 00037 00038 typedef MultilayerNeuralNetworkBase Self; 00039 typedef NeuralNetworkObject<TMeasurementVector, TTargetVector> 00040 Superclass; 00041 typedef SmartPointer<Self> Pointer; 00042 typedef SmartPointer<const Self> ConstPointer; 00043 00044 itkTypeMacro(MultilayerNeuralNetworkBase, NeuralNetworkObject); 00045 00047 itkNewMacro( Self ); 00048 00049 typedef typename Superclass::ValueType ValueType; 00050 typedef typename Superclass::MeasurementVectorType MeasurementVectorType; 00051 typedef typename Superclass::TargetVectorType TargetVectorType; 00052 typedef typename Superclass::NetworkOutputType NetworkOutputType; 00053 00054 typedef typename 
Superclass::LayerInterfaceType LayerInterfaceType; 00055 00056 typedef TLearningLayer LearningLayerType; 00057 typedef LearningFunctionBase<typename TLearningLayer::LayerInterfaceType, TTargetVector> 00058 LearningFunctionInterfaceType; 00059 00060 typedef std::vector<typename LayerInterfaceType::WeightSetInterfaceType::Pointer> 00061 WeightVectorType; 00062 typedef std::vector<typename LayerInterfaceType::Pointer> 00063 LayerVectorType; 00064 00065 typedef TransferFunctionBase<ValueType> TransferFunctionInterfaceType; 00066 typedef InputFunctionBase<ValueType*, ValueType> InputFunctionInterfaceType; 00067 00068 //#define __USE_OLD_INTERFACE Comment out to ensure that new interface works 00069 #ifdef __USE_OLD_INTERFACE 00070 itkSetMacro(NumOfLayers, int); 00071 itkGetConstReferenceMacro(NumOfLayers, int); 00072 00073 itkSetMacro(NumOfWeightSets, int); 00074 itkGetConstReferenceMacro(NumOfWeightSets, int); 00075 #else 00076 int GetNumOfLayers(void) const 00077 { 00078 return m_Layers.size(); 00079 } 00080 int GetNumOfWeightSets(void) const 00081 { 00082 return m_Weights.size(); 00083 } 00084 00085 #endif 00086 00087 void AddLayer(LayerInterfaceType *); 00088 LayerInterfaceType * GetLayer(int layer_id); 00089 const LayerInterfaceType * GetLayer(int layer_id) const; 00090 00091 void AddWeightSet(typename LayerInterfaceType::WeightSetInterfaceType*); 00092 typename LayerInterfaceType::WeightSetInterfaceType* GetWeightSet(unsigned int id) 00093 { 00094 return m_Weights[id].GetPointer(); 00095 } 00096 #ifdef __USE_OLD_INTERFACE 00097 const typename LayerInterfaceType::WeightSetInterfaceType* GetWeightSet(unsigned int id) const; 00098 #endif 00099 00100 void SetLearningFunction(LearningFunctionInterfaceType* f); 00101 00102 virtual NetworkOutputType GenerateOutput(TMeasurementVector samplevector); 00103 00104 virtual void BackwardPropagate(NetworkOutputType errors); 00105 virtual void UpdateWeights(ValueType); 00106 00107 void 
SetLearningRule(LearningFunctionInterfaceType*); 00108 00109 void SetLearningRate(ValueType learningrate); 00110 00111 void InitializeWeights(); 00112 00113 protected: 00114 MultilayerNeuralNetworkBase(); 00115 ~MultilayerNeuralNetworkBase(); 00116 00117 LayerVectorType m_Layers; 00118 WeightVectorType m_Weights; 00119 typename LearningFunctionInterfaceType::Pointer m_LearningFunction; 00120 ValueType m_LearningRate; 00121 //#define __USE_OLD_INTERFACE Comment out to ensure that new interface works 00122 #ifdef __USE_OLD_INTERFACE 00123 //These are completely redundant variables that can be more reliably queried from 00124 // m_Layers->size() and m_Weights->size(); 00125 int m_NumOfLayers; 00126 int m_NumOfWeightSets; 00127 #endif 00128 00129 virtual void PrintSelf( std::ostream& os, Indent indent ) const; 00130 }; 00131 00132 } // end namespace Statistics 00133 } // end namespace itk 00134 00135 #ifndef ITK_MANUAL_INSTANTIATION 00136 #include "itkMultilayerNeuralNetworkBase.hxx" 00137 #endif 00138 00139 #endif 00140