00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016 #ifndef __MultiLayerNeuralNetworkBase_h
00017 #define __MultiLayerNeuralNetworkBase_h
00018
00019 #include "itkNeuralNetworkObject.h"
00020 #include "itkLayerBase.h"
00021
00022 namespace itk
00023 {
00024 namespace Statistics
00025 {
00026
00027 template<class TMeasurementVector, class TTargetVector,class TLearningLayer=LayerBase<TMeasurementVector, TTargetVector> >
00028 class MultilayerNeuralNetworkBase : public NeuralNetworkObject<TMeasurementVector, TTargetVector>
00029 {
00030 public:
00031
00032 typedef MultilayerNeuralNetworkBase Self;
00033 typedef NeuralNetworkObject<TMeasurementVector, TTargetVector> Superclass;
00034 typedef SmartPointer<Self> Pointer;
00035 typedef SmartPointer<const Self> ConstPointer;
00036 itkTypeMacro(MultilayerNeuralNetworkBase, NeuralNetworkObject);
00037
00039 itkNewMacro( Self );
00040
00041 typedef typename Superclass::ValueType ValueType;
00042 typedef typename Superclass::MeasurementVectorType MeasurementVectorType;
00043 typedef typename Superclass::TargetVectorType TargetVectorType;
00044 typedef typename Superclass::NetworkOutputType NetworkOutputType;
00045
00046 typedef typename Superclass::LayerInterfaceType LayerInterfaceType;
00047 typedef TLearningLayer LearningLayerType;
00048 typedef LearningFunctionBase<typename TLearningLayer::LayerInterfaceType, TTargetVector> LearningFunctionInterfaceType;
00049
00050 typedef std::vector<typename LayerInterfaceType::WeightSetInterfaceType::Pointer> WeightVectorType;
00051 typedef std::vector<typename LayerInterfaceType::Pointer> LayerVectorType;
00052
00053 typedef TransferFunctionBase<ValueType> TransferFunctionInterfaceType;
00054 typedef InputFunctionBase<ValueType*, ValueType> InputFunctionInterfaceType;
00055
00056
00057 #ifdef __USE_OLD_INTERFACE
00058 itkSetMacro(NumOfLayers, int);
00059 itkGetConstReferenceMacro(NumOfLayers, int);
00060
00061 itkSetMacro(NumOfWeightSets, int);
00062 itkGetConstReferenceMacro(NumOfWeightSets, int);
00063 #else
00064 int GetNumOfLayers(void) const
00065 {
00066 return m_Layers.size();
00067 }
00068 int GetNumOfWeightSets(void) const
00069 {
00070 return m_Weights.size();
00071 }
00072
00073 #endif
00074
00075 void AddLayer(LayerInterfaceType *);
00076 LayerInterfaceType * GetLayer(int layer_id);
00077 const LayerInterfaceType * GetLayer(int layer_id) const;
00078
00079 void AddWeightSet(typename LayerInterfaceType::WeightSetInterfaceType*);
00080 typename LayerInterfaceType::WeightSetInterfaceType* GetWeightSet(unsigned int id)
00081 {
00082 return m_Weights[id].GetPointer();
00083 }
00084 #ifdef __USE_OLD_INTERFACE
00085 const typename LayerInterfaceType::WeightSetInterfaceType* GetWeightSet(unsigned int id) const;
00086 #endif
00087
00088 void SetLearningFunction(LearningFunctionInterfaceType* f);
00089
00090 virtual NetworkOutputType GenerateOutput(TMeasurementVector samplevector);
00091
00092 virtual void BackwardPropagate(NetworkOutputType errors);
00093 virtual void UpdateWeights(ValueType);
00094
00095 void SetLearningRule(LearningFunctionInterfaceType*);
00096
00097 void SetLearningRate(ValueType learningrate);
00098
00099 void InitializeWeights();
00100
00101 protected:
00102 MultilayerNeuralNetworkBase();
00103 ~MultilayerNeuralNetworkBase();
00104
00105 LayerVectorType m_Layers;
00106 WeightVectorType m_Weights;
00107 typename LearningFunctionInterfaceType::Pointer m_LearningFunction;
00108 ValueType m_LearningRate;
00109
00110 #ifdef __USE_OLD_INTERFACE
00111
00112
00113 int m_NumOfLayers;
00114 int m_NumOfWeightSets;
00115 #endif
00116
00117 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00118 };
00119
00120 }
00121 }
00122
00123 #ifndef ITK_MANUAL_INSTANTIATION
00124 #include "itkMultilayerNeuralNetworkBase.txx"
00125 #endif
00126
00127 #endif
00128