00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkBackPropagationLayerBase_h
00018 #define __itkBackPropagationLayerBase_h
00019
00020 #include "itkCompletelyConnectedWeightSet.h"
00021 #include "itkLayerBase.h"
00022 #include "itkObject.h"
00023 #include "itkMacro.h"
00024
00025 namespace itk
00026 {
00027 namespace Statistics
00028 {
00029 template<class TMeasurementVector, class TTargetVector>
00030 class BackPropagationLayer : public LayerBase<TMeasurementVector, TTargetVector>
00031 {
00032 public:
00033 typedef BackPropagationLayer Self;
00034 typedef LayerBase<TMeasurementVector, TTargetVector> Superclass;
00035 typedef SmartPointer<Self> Pointer;
00036 typedef SmartPointer<const Self> ConstPointer;
00037
00039 itkTypeMacro(BackPropagationLayer, LayerBase);
00040 itkNewMacro(Self);
00042
00043 typedef typename Superclass::ValueType ValueType;
00044 typedef vnl_vector<ValueType> NodeVectorType;
00045 typedef typename Superclass::InternalVectorType InternalVectorType;
00046 typedef typename Superclass::OutputVectorType OutputVectorType;
00047 typedef typename Superclass::LayerInterfaceType LayerInterfaceType;
00048 typedef CompletelyConnectedWeightSet<TMeasurementVector,TTargetVector> WeightSetType;
00049
00050 typedef typename Superclass::WeightSetInterfaceType WeightSetInterfaceType;
00051 typedef typename Superclass::InputFunctionInterfaceType InputFunctionInterfaceType;
00052 typedef typename Superclass::TransferFunctionInterfaceType TransferFunctionInterfaceType;
00053
00054 virtual void SetNumberOfNodes(unsigned int numNodes);
00055 virtual ValueType GetInputValue(unsigned int i) const;
00056 virtual void SetInputValue(unsigned int i, ValueType value);
00057
00058 virtual ValueType GetOutputValue(unsigned int) const;
00059 virtual void SetOutputValue(unsigned int, ValueType);
00060
00061 virtual ValueType * GetOutputVector();
00062 void SetOutputVector(TMeasurementVector value);
00063
00064 virtual void ForwardPropagate();
00065 virtual void ForwardPropagate(TMeasurementVector input);
00066
00067 virtual void BackwardPropagate();
00068 virtual void BackwardPropagate(InternalVectorType errors);
00069
00070 virtual void SetOutputErrorValues(TTargetVector);
00071 virtual ValueType GetOutputErrorValue(unsigned int node_id) const;
00072
00073 virtual ValueType GetInputErrorValue(unsigned int node_id) const;
00074 virtual ValueType * GetInputErrorVector();
00075 virtual void SetInputErrorValue(ValueType, unsigned int node_id);
00076
00077 virtual ValueType Activation(ValueType);
00078 virtual ValueType DActivation(ValueType);
00079
00081 itkSetMacro( Bias, ValueType );
00082 itkGetConstReferenceMacro( Bias, ValueType );
00084
00085 protected:
00086
00087 BackPropagationLayer();
00088 virtual ~BackPropagationLayer();
00089
00091 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00092
00093 private:
00094
00095 NodeVectorType m_NodeInputValues;
00096 NodeVectorType m_NodeOutputValues;
00097 NodeVectorType m_InputErrorValues;
00098 NodeVectorType m_OutputErrorValues;
00099 ValueType m_Bias;
00100 };
00101
00102 }
00103 }
00104
00105 #ifndef ITK_MANUAL_INSTANTIATION
00106 #include "itkBackPropagationLayer.txx"
00107 #endif
00108
00109 #endif
00110