ITK  4.0.0
Insight Segmentation and Registration Toolkit
itkMultilayerNeuralNetworkBase.h
Go to the documentation of this file.
00001 /*=========================================================================
00002  *
00003  *  Copyright Insight Software Consortium
00004  *
00005  *  Licensed under the Apache License, Version 2.0 (the "License");
00006  *  you may not use this file except in compliance with the License.
00007  *  You may obtain a copy of the License at
00008  *
00009  *         http://www.apache.org/licenses/LICENSE-2.0.txt
00010  *
00011  *  Unless required by applicable law or agreed to in writing, software
00012  *  distributed under the License is distributed on an "AS IS" BASIS,
00013  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00014  *  See the License for the specific language governing permissions and
00015  *  limitations under the License.
00016  *
00017  *=========================================================================*/
00018 #ifndef __itkMultilayerNeuralNetworkBase_h
00019 #define __itkMultilayerNeuralNetworkBase_h
00020 
00021 #include "itkNeuralNetworkObject.h"
00022 
00023 namespace itk
00024 {
00025 namespace Statistics
00026 {
/** \class MultilayerNeuralNetworkBase
 * \brief Neural network object that stores an indexed collection of layers
 *        and an indexed collection of weight sets.
 *
 * The layers are kept in m_Layers and the weight sets in m_Weights (both are
 * std::vector containers of Pointer handles).  The class additionally holds a
 * learning function and a learning rate.  Only the small accessors below are
 * defined inline; every other member is merely declared here and is expected
 * to be defined in itkMultilayerNeuralNetworkBase.hxx, which this header
 * includes unless ITK_MANUAL_INSTANTIATION is defined.
 *
 * \tparam TMeasurementVector Forwarded to NeuralNetworkObject; also the
 *         argument type of GenerateOutput().
 * \tparam TTargetVector      Forwarded to NeuralNetworkObject and to
 *         LearningFunctionBase.
 * \tparam TLearningLayer     Layer type exposed as LearningLayerType; its
 *         nested LayerInterfaceType parameterizes the learning function.
 *         Defaults to LayerBase<TMeasurementVector, TTargetVector>.
 */
00033 template<class TMeasurementVector, class TTargetVector,class TLearningLayer=LayerBase<TMeasurementVector, TTargetVector> >
00034 class MultilayerNeuralNetworkBase : public NeuralNetworkObject<TMeasurementVector, TTargetVector>
00035 {
00036 public:
00037 
  /** Standard class typedefs. */
00038   typedef MultilayerNeuralNetworkBase  Self;
00039   typedef NeuralNetworkObject<TMeasurementVector, TTargetVector>
00040                                        Superclass;
00041   typedef SmartPointer<Self>           Pointer;
00042   typedef SmartPointer<const Self>     ConstPointer;
00043 
  /** Run-time type information (ITK type macro). */
00044   itkTypeMacro(MultilayerNeuralNetworkBase, NeuralNetworkObject);
00045 
  /** Method for creation through the object factory (ITK New() macro). */
00047   itkNewMacro( Self );
00048 
  /** Types re-exported from the superclass. */
00049   typedef typename Superclass::ValueType             ValueType;
00050   typedef typename Superclass::MeasurementVectorType MeasurementVectorType;
00051   typedef typename Superclass::TargetVectorType      TargetVectorType;
00052   typedef typename Superclass::NetworkOutputType     NetworkOutputType;
00053 
00054   typedef typename Superclass::LayerInterfaceType        LayerInterfaceType;
00055 
  /** Learning layer type and the learning-function interface built on it. */
00056   typedef TLearningLayer                                 LearningLayerType;
00057   typedef LearningFunctionBase<typename TLearningLayer::LayerInterfaceType, TTargetVector>
00058                                                          LearningFunctionInterfaceType;
00059 
  /** Container types used for m_Weights and m_Layers respectively. */
00060   typedef std::vector<typename LayerInterfaceType::WeightSetInterfaceType::Pointer>
00061                                                          WeightVectorType;
00062   typedef std::vector<typename LayerInterfaceType::Pointer>
00063                                                          LayerVectorType;
00064 
  /** Transfer-function and input-function interfaces over ValueType. */
00065   typedef TransferFunctionBase<ValueType>          TransferFunctionInterfaceType;
00066   typedef InputFunctionBase<ValueType*, ValueType> InputFunctionInterfaceType;
00067 
  // Legacy-interface switch.  The define below is left commented out, so
  // (unless __USE_OLD_INTERFACE is defined elsewhere) the size-based accessors
  // in the #else branch are the ones that get compiled.
00068 //#define __USE_OLD_INTERFACE  Comment out to ensure that new interface works
00069 #ifdef __USE_OLD_INTERFACE
  /** Legacy accessors backed by the separately stored counters. */
00070   itkSetMacro(NumOfLayers, int);
00071   itkGetConstReferenceMacro(NumOfLayers, int);
00072 
00073   itkSetMacro(NumOfWeightSets, int);
00074   itkGetConstReferenceMacro(NumOfWeightSets, int);
00075 #else
  /** Number of layers currently stored, i.e. the size of m_Layers.
   *  The explicit cast spells out the size_type -> int conversion that the
   *  int return type requires, instead of relying on an implicit narrowing. */
00076   int GetNumOfLayers(void) const
00077     {
00078     return static_cast<int>(m_Layers.size());
00079     }
  /** Number of weight sets currently stored, i.e. the size of m_Weights.
   *  Cast is explicit for the same reason as in GetNumOfLayers(). */
00080   int GetNumOfWeightSets(void) const
00081     {
00082     return static_cast<int>(m_Weights.size());
00083     }
00084 
00085 #endif
00086 
  /** Layer management; defined outside this header.
   *  NOTE(review): AddLayer presumably appends to m_Layers and GetLayer
   *  presumably indexes into it -- confirm against the .hxx. */
00087   void AddLayer(LayerInterfaceType *);
00088   LayerInterfaceType * GetLayer(int layer_id);
00089   const LayerInterfaceType * GetLayer(int layer_id) const;
00090 
  /** Weight-set management.  AddWeightSet is defined outside this header
   *  (presumably appends to m_Weights -- confirm against the .hxx). */
00091   void AddWeightSet(typename LayerInterfaceType::WeightSetInterfaceType*);
  /** Returns the raw pointer held by the handle at m_Weights[id].
   *  The index is used with std::vector::operator[], so no range check is
   *  performed: id must be smaller than GetNumOfWeightSets(). */
00092   typename LayerInterfaceType::WeightSetInterfaceType* GetWeightSet(unsigned int id)
00093     {
00094     return m_Weights[id].GetPointer();
00095     }
  // The const overload is only declared when the legacy interface is enabled.
00096 #ifdef __USE_OLD_INTERFACE
00097   const typename LayerInterfaceType::WeightSetInterfaceType* GetWeightSet(unsigned int id) const;
00098 #endif
00099 
  /** Sets the learning function; defined outside this header.
   *  NOTE(review): SetLearningRule() below takes the same argument type --
   *  check the .hxx to see whether the two setters differ. */
00100   void SetLearningFunction(LearningFunctionInterfaceType* f);
00101 
  /** Produces the network output for one sample (passed by value). */
00102   virtual NetworkOutputType GenerateOutput(TMeasurementVector samplevector);
00103 
  /** Training hooks, overridable by subclasses; defined outside this header. */
00104   virtual void BackwardPropagate(NetworkOutputType errors);
00105   virtual void UpdateWeights(ValueType);
00106 
00107   void SetLearningRule(LearningFunctionInterfaceType*);
00108 
00109   void SetLearningRate(ValueType learningrate);
00110 
00111   void InitializeWeights();
00112 
00113 protected:
  /** Construction and destruction are protected; instances are created
   *  through the New() method generated by itkNewMacro above. */
00114   MultilayerNeuralNetworkBase();
00115   ~MultilayerNeuralNetworkBase();
00116 
  /** Stored layers, stored weight sets, learning function and learning rate. */
00117   LayerVectorType                                   m_Layers;
00118   WeightVectorType                                  m_Weights;
00119   typename LearningFunctionInterfaceType::Pointer   m_LearningFunction;
00120   ValueType                                         m_LearningRate;
00121   //#define __USE_OLD_INTERFACE  Comment out to ensure that new interface works
00122 #ifdef __USE_OLD_INTERFACE
00123   // These variables are completely redundant: the same values can be queried
00124   // more reliably from m_Layers.size() and m_Weights.size().
00125   int                             m_NumOfLayers;
00126   int                             m_NumOfWeightSets;
00127 #endif
00128 
  /** Prints the object's state to the given stream using the given indent. */
00129   virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00130 };
00131 
00132 } // end namespace Statistics
00133 } // end namespace itk
00134 
00135 #ifndef ITK_MANUAL_INSTANTIATION
00136 #include "itkMultilayerNeuralNetworkBase.hxx"
00137 #endif
00138 
00139 #endif
00140