00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkMultilayerNeuralNetworkBase_h
00018 #define __itkMultilayerNeuralNetworkBase_h
00019
00020 #include "itkNeuralNetworkObject.h"
00021 #include "itkLayerBase.h"
00022
00023 namespace itk
00024 {
00025 namespace Statistics
00026 {
00027
00028 template<class TMeasurementVector, class TTargetVector,class TLearningLayer=LayerBase<TMeasurementVector, TTargetVector> >
00029 class MultilayerNeuralNetworkBase : public NeuralNetworkObject<TMeasurementVector, TTargetVector>
00030 {
00031 public:
00032
00033 typedef MultilayerNeuralNetworkBase Self;
00034 typedef NeuralNetworkObject<TMeasurementVector, TTargetVector>
00035 Superclass;
00036 typedef SmartPointer<Self> Pointer;
00037 typedef SmartPointer<const Self> ConstPointer;
00038
00039 itkTypeMacro(MultilayerNeuralNetworkBase, NeuralNetworkObject);
00040
00042 itkNewMacro( Self );
00043
00044 typedef typename Superclass::ValueType ValueType;
00045 typedef typename Superclass::MeasurementVectorType MeasurementVectorType;
00046 typedef typename Superclass::TargetVectorType TargetVectorType;
00047 typedef typename Superclass::NetworkOutputType NetworkOutputType;
00048
00049 typedef typename Superclass::LayerInterfaceType LayerInterfaceType;
00050
00051 typedef TLearningLayer LearningLayerType;
00052 typedef LearningFunctionBase<typename TLearningLayer::LayerInterfaceType, TTargetVector>
00053 LearningFunctionInterfaceType;
00054
00055 typedef std::vector<typename LayerInterfaceType::WeightSetInterfaceType::Pointer>
00056 WeightVectorType;
00057 typedef std::vector<typename LayerInterfaceType::Pointer>
00058 LayerVectorType;
00059
00060 typedef TransferFunctionBase<ValueType> TransferFunctionInterfaceType;
00061 typedef InputFunctionBase<ValueType*, ValueType> InputFunctionInterfaceType;
00062
00063
00064 #ifdef __USE_OLD_INTERFACE
00065 itkSetMacro(NumOfLayers, int);
00066 itkGetConstReferenceMacro(NumOfLayers, int);
00067
00068 itkSetMacro(NumOfWeightSets, int);
00069 itkGetConstReferenceMacro(NumOfWeightSets, int);
00070 #else
00071 int GetNumOfLayers(void) const
00072 {
00073 return m_Layers.size();
00074 }
00075 int GetNumOfWeightSets(void) const
00076 {
00077 return m_Weights.size();
00078 }
00079
00080 #endif
00081
00082 void AddLayer(LayerInterfaceType *);
00083 LayerInterfaceType * GetLayer(int layer_id);
00084 const LayerInterfaceType * GetLayer(int layer_id) const;
00085
00086 void AddWeightSet(typename LayerInterfaceType::WeightSetInterfaceType*);
00087 typename LayerInterfaceType::WeightSetInterfaceType* GetWeightSet(unsigned int id)
00088 {
00089 return m_Weights[id].GetPointer();
00090 }
00091 #ifdef __USE_OLD_INTERFACE
00092 const typename LayerInterfaceType::WeightSetInterfaceType* GetWeightSet(unsigned int id) const;
00093 #endif
00094
00095 void SetLearningFunction(LearningFunctionInterfaceType* f);
00096
00097 virtual NetworkOutputType GenerateOutput(TMeasurementVector samplevector);
00098
00099 virtual void BackwardPropagate(NetworkOutputType errors);
00100 virtual void UpdateWeights(ValueType);
00101
00102 void SetLearningRule(LearningFunctionInterfaceType*);
00103
00104 void SetLearningRate(ValueType learningrate);
00105
00106 void InitializeWeights();
00107
00108 protected:
00109 MultilayerNeuralNetworkBase();
00110 ~MultilayerNeuralNetworkBase();
00111
00112 LayerVectorType m_Layers;
00113 WeightVectorType m_Weights;
00114 typename LearningFunctionInterfaceType::Pointer m_LearningFunction;
00115 ValueType m_LearningRate;
00116
00117 #ifdef __USE_OLD_INTERFACE
00118
00119
00120 int m_NumOfLayers;
00121 int m_NumOfWeightSets;
00122 #endif
00123
00124 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00125 };
00126
00127 }
00128 }
00129
00130 #ifndef ITK_MANUAL_INSTANTIATION
00131 #include "itkMultilayerNeuralNetworkBase.txx"
00132 #endif
00133
00134 #endif
00135