itkRBFLayer.h
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkRBFLayer_h
00018 #define __itkRBFLayer_h
00019
00020 #include "itkCompletelyConnectedWeightSet.h"
00021 #include "itkLayerBase.h"
00022 #include "itkObject.h"
00023 #include "itkMacro.h"
00024 #include "itkRadialBasisFunctionBase.h"
00025 #include "itkEuclideanDistance.h"
00026
00027 namespace itk
00028 {
00029 namespace Statistics
00030 {
00031 template<class TMeasurementVector, class TTargetVector>
00032 class RBFLayer : public LayerBase<TMeasurementVector, TTargetVector>
00033 {
00034 public:
00035 typedef RBFLayer Self;
00036 typedef LayerBase<TMeasurementVector, TTargetVector> Superclass;
00037 typedef SmartPointer<Self> Pointer;
00038 typedef SmartPointer<const Self> ConstPointer;
00039
00041 itkTypeMacro(RBFLayer, LayerBase);
00042 itkNewMacro(Self);
00044
00045 typedef typename Superclass::ValueType ValueType;
00046 typedef typename Superclass::ValuePointer ValuePointer;
00047 typedef vnl_vector<ValueType> NodeVectorType;
00048 typedef typename Superclass::InternalVectorType InternalVectorType;
00049 typedef typename Superclass::OutputVectorType OutputVectorType;
00050 typedef typename Superclass::LayerInterfaceType LayerInterfaceType;
00051 typedef CompletelyConnectedWeightSet<TMeasurementVector,TTargetVector>
00052 WeightSetType;
00053
00054 typedef typename Superclass::WeightSetInterfaceType WeightSetInterfaceType;
00055 typedef typename Superclass::InputFunctionInterfaceType InputFunctionInterfaceType;
00056 typedef typename Superclass::TransferFunctionInterfaceType TransferFunctionInterfaceType;
00057
00058
00059 typedef EuclideanDistance<InternalVectorType> DistanceMetricType;
00060 typedef typename DistanceMetricType::Pointer DistanceMetricPointer;
00061 typedef RadialBasisFunctionBase<ValueType> RBFType;
00062
00063
00064 itkGetConstReferenceMacro(RBF_Dim, unsigned int);
00065 void SetRBF_Dim(unsigned int size);
00066 virtual void SetNumberOfNodes(unsigned int numNodes);
00067 virtual ValueType GetInputValue(unsigned int i) const;
00068 void SetInputValue(unsigned int i, ValueType value);
00069
00070 virtual ValueType GetOutputValue(unsigned int) const;
00071 virtual void SetOutputValue(unsigned int, ValueType);
00072
00073 virtual ValueType * GetOutputVector();
00074 void SetOutputVector(TMeasurementVector value);
00075
00076 virtual void ForwardPropagate();
00077 virtual void ForwardPropagate(TMeasurementVector input);
00078
00079 virtual void BackwardPropagate();
00080 virtual void BackwardPropagate(TTargetVector itkNotUsed(errors)){};
00081
00082 virtual void SetOutputErrorValues(TTargetVector);
00083 virtual ValueType GetOutputErrorValue(unsigned int node_id) const;
00084
00085 virtual ValueType GetInputErrorValue(unsigned int node_id) const;
00086 virtual ValueType * GetInputErrorVector();
00087 virtual void SetInputErrorValue(ValueType, unsigned int node_id);
00088
00089
00090 InternalVectorType GetCenter(unsigned int i) const;
00091 void SetCenter(TMeasurementVector c,unsigned int i);
00092
00093 ValueType GetRadii(unsigned int i) const;
00094 void SetRadii(ValueType c,unsigned int i);
00095
00096 virtual ValueType Activation(ValueType);
00097 virtual ValueType DActivation(ValueType);
00098
00100 itkSetMacro( Bias, ValueType );
00101 itkGetConstReferenceMacro( Bias, ValueType );
00103
00104 void SetDistanceMetric(DistanceMetricType* f);
00105 itkGetObjectMacro( DistanceMetric, DistanceMetricType );
00106 itkGetConstObjectMacro( DistanceMetric, DistanceMetricType );
00107
00108 itkSetMacro(NumClasses,unsigned int);
00109 itkGetConstReferenceMacro(NumClasses,unsigned int);
00110
00111 void SetRBF(RBFType* f);
00112 itkGetObjectMacro(RBF, RBFType);
00113 itkGetConstObjectMacro(RBF, RBFType);
00114
00115 protected:
00116
00117 RBFLayer();
00118 virtual ~RBFLayer();
00119
00121 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00122
00123 private:
00124
00125 NodeVectorType m_NodeInputValues;
00126 NodeVectorType m_NodeOutputValues;
00127 NodeVectorType m_InputErrorValues;
00128 NodeVectorType m_OutputErrorValues;
00129
00130 typename DistanceMetricType::Pointer m_DistanceMetric;
00131
00132 std::vector<InternalVectorType> m_Centers;
00133 InternalVectorType m_Radii;
00134 unsigned int m_NumClasses;
00135 ValueType m_Bias;
00136 unsigned int m_RBF_Dim;
00137 typename RBFType::Pointer m_RBF;
00138 };
00139
00140 }
00141 }
00142
00143 #ifndef ITK_MANUAL_INSTANTIATION
00144 #include "itkRBFLayer.txx"
00145 #endif
00146
00147 #endif
00148