00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkBackPropagationLayer_h
00018 #define __itkBackPropagationLayer_h
00019
00020 #include "itkCompletelyConnectedWeightSet.h"
00021 #include "itkLayerBase.h"
00022 #include "itkObject.h"
00023 #include "itkMacro.h"
00024
00025 namespace itk
00026 {
00027 namespace Statistics
00028 {
00029 template<class TMeasurementVector, class TTargetVector>
00030 class BackPropagationLayer : public LayerBase<TMeasurementVector, TTargetVector>
00031 {
00032 public:
00033 typedef BackPropagationLayer Self;
00034 typedef LayerBase<TMeasurementVector, TTargetVector> Superclass;
00035 typedef SmartPointer<Self> Pointer;
00036 typedef SmartPointer<const Self> ConstPointer;
00037
00039 itkTypeMacro(BackPropagationLayer, LayerBase);
00040 itkNewMacro(Self);
00042
00043 typedef typename Superclass::ValueType ValueType;
00044 typedef vnl_vector<ValueType> NodeVectorType;
00045 typedef typename Superclass::InternalVectorType InternalVectorType;
00046 typedef typename Superclass::OutputVectorType OutputVectorType;
00047 typedef typename Superclass::LayerInterfaceType LayerInterfaceType;
00048 typedef CompletelyConnectedWeightSet<TMeasurementVector,TTargetVector>
00049 WeightSetType;
00050
00051 typedef typename Superclass::WeightSetInterfaceType WeightSetInterfaceType;
00052 typedef typename Superclass::InputFunctionInterfaceType
00053 InputFunctionInterfaceType;
00054 typedef typename Superclass::TransferFunctionInterfaceType
00055 TransferFunctionInterfaceType;
00056
00057 virtual void SetNumberOfNodes(unsigned int numNodes);
00058 virtual ValueType GetInputValue(unsigned int i) const;
00059 virtual void SetInputValue(unsigned int i, ValueType value);
00060
00061 virtual ValueType GetOutputValue(unsigned int) const;
00062 virtual void SetOutputValue(unsigned int, ValueType);
00063
00064 virtual ValueType * GetOutputVector();
00065 void SetOutputVector(TMeasurementVector value);
00066
00067 virtual void ForwardPropagate();
00068 virtual void ForwardPropagate(TMeasurementVector input);
00069
00070 virtual void BackwardPropagate();
00071 virtual void BackwardPropagate(InternalVectorType errors);
00072
00073 virtual void SetOutputErrorValues(TTargetVector);
00074 virtual ValueType GetOutputErrorValue(unsigned int node_id) const;
00075
00076 virtual ValueType GetInputErrorValue(unsigned int node_id) const;
00077 virtual ValueType * GetInputErrorVector();
00078 virtual void SetInputErrorValue(ValueType, unsigned int node_id);
00079
00080 virtual ValueType Activation(ValueType);
00081 virtual ValueType DActivation(ValueType);
00082
00084 itkSetMacro( Bias, ValueType );
00085 itkGetConstReferenceMacro( Bias, ValueType );
00087
00088 protected:
00089
00090 BackPropagationLayer();
00091 virtual ~BackPropagationLayer();
00092
00094 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00095
00096 private:
00097
00098 NodeVectorType m_NodeInputValues;
00099 NodeVectorType m_NodeOutputValues;
00100 NodeVectorType m_InputErrorValues;
00101 NodeVectorType m_OutputErrorValues;
00102 ValueType m_Bias;
00103 };
00104
00105 }
00106 }
00107
00108 #ifndef ITK_MANUAL_INSTANTIATION
00109 #include "itkBackPropagationLayer.txx"
00110 #endif
00111
00112 #endif
00113