00001 /*========================================================================= 00002 00003 Program: Insight Segmentation & Registration Toolkit 00004 Module: $RCSfile: itkMultilayerNeuralNetworkBase.h,v $ 00005 Language: C++ 00006 Date: $Date: 2007/01/19 20:39:21 $ 00007 Version: $Revision: 1.6 $ 00008 00009 Copyright (c) Insight Software Consortium. All rights reserved. 00010 See ITKCopyright.txt or http://www.itk.org/HTML/Copyright.htm for details. 00011 00012 This software is distributed WITHOUT ANY WARRANTY; without even 00013 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR 00014 PURPOSE. See the above copyright notices for more information. 00015 00016 =========================================================================*/ 00017 #ifndef __MultiLayerNeuralNetworkBase_h 00018 #define __MultiLayerNeuralNetworkBase_h 00019 00020 #include "itkNeuralNetworkObject.h" 00021 #include "itkErrorBackPropagationLearningFunctionBase.h" 00022 #include "itkErrorBackPropagationLearningWithMomentum.h" 00023 #include "itkQuickPropLearningRule.h" 00024 00025 namespace itk 00026 { 00027 namespace Statistics 00028 { 00029 00030 template<class TVector, class TOutput> 00031 class MultilayerNeuralNetworkBase : public NeuralNetworkObject<TVector, TOutput> 00032 { 00033 public: 00034 00035 typedef MultilayerNeuralNetworkBase Self; 00036 typedef NeuralNetworkObject<TVector, TOutput> Superclass; 00037 typedef SmartPointer<Self> Pointer; 00038 typedef SmartPointer<const Self> ConstPointer; 00039 itkTypeMacro(MultilayerNeuralNetworkBase, NeuralNetworkObject); 00040 00042 itkNewMacro( Self ); 00043 00044 typedef typename Superclass::ValueType ValueType; 00045 typedef typename Superclass::NetworkOutputType NetworkOutputType; 00046 typedef typename Superclass::LayerType LayerType; 00047 typedef typename Superclass::WeightSetType WeightSetType; 00048 typedef typename Superclass::WeightSetPointer WeightSetPointer; 00049 typedef typename Superclass::LayerPointer LayerPointer; 00050 typedef typename Superclass::LearningFunctionType LearningFunctionType; 00051 typedef typename Superclass::LearningFunctionPointer LearningFunctionPointer; 00052 00053 typedef std::vector<WeightSetPointer> WeightVectorType; 00054 typedef std::vector<LayerPointer> LayerVectorType; 00055 00056 itkSetMacro(NumOfLayers, int); 00057 itkGetConstReferenceMacro(NumOfLayers, int); 00058 00059 itkSetMacro(NumOfWeightSets, int); 00060 itkGetConstReferenceMacro(NumOfWeightSets, int); 00061 00062 void AddLayer(LayerType*); 00063 LayerType* GetLayer(int layer_id); 00064 const LayerType* GetLayer(int layer_id) const; 00065 00066 void AddWeightSet(WeightSetType*); 00067 WeightSetType* GetWeightSet(unsigned int id); 00068 const WeightSetType* GetWeightSet(unsigned int id) const; 00069 00070 void SetLearningFunction(LearningFunctionType* f); 00071 00072 // virtual ValueType* GenerateOutput(TVector samplevector); 00073 virtual NetworkOutputType GenerateOutput(TVector samplevector); 00074 00075 // virtual void BackwardPropagate(TOutput errors); 00076 virtual void BackwardPropagate(NetworkOutputType errors); 00077 00078 virtual void UpdateWeights(ValueType); 00079 00080 void SetLearningRule(LearningFunctionType*); 00081 00082 void SetLearningRate(ValueType learningrate); 00083 00084 void InitializeWeights(); 00085 00086 protected: 00087 MultilayerNeuralNetworkBase(); 00088 ~MultilayerNeuralNetworkBase(); 00089 00090 LayerVectorType m_Layers; 00091 WeightVectorType m_Weights; 00092 LearningFunctionPointer m_LearningFunction; 00093 ValueType m_LearningRate; 00094 int m_NumOfLayers; 00095 int m_NumOfWeightSets; 00097 virtual void PrintSelf( std::ostream& os, Indent indent ) const; 00098 }; 00099 00100 } // end namespace Statistics 00101 } // end namespace itk 00102 00103 #ifndef ITK_MANUAL_INSTANTIATION 00104 #include "itkMultilayerNeuralNetworkBase.txx" 00105 #endif 00106 00107 #endif 00108