Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkRBFLayer_h
00018 #define __itkRBFLayer_h
00019
00020 #include "itkCompletelyConnectedWeightSet.h"
00021 #include "itkLayerBase.h"
00022 #include "itkObject.h"
00023 #include "itkMacro.h"
00024 #include "itkRadialBasisFunctionBase.h"
00025 #ifdef ITK_USE_REVIEW_STATISTICS
00026 #include "itkEuclideanDistanceMetric.h"
00027 #else
00028 #include "itkEuclideanDistance.h"
00029 #endif
00030
00031 namespace itk
00032 {
00033 namespace Statistics
00034 {
00035 template<class TMeasurementVector, class TTargetVector>
00036 class RBFLayer : public LayerBase<TMeasurementVector, TTargetVector>
00037 {
00038 public:
00039 typedef RBFLayer Self;
00040 typedef LayerBase<TMeasurementVector, TTargetVector> Superclass;
00041 typedef SmartPointer<Self> Pointer;
00042 typedef SmartPointer<const Self> ConstPointer;
00043
00045 itkTypeMacro(RBFLayer, LayerBase);
00046 itkNewMacro(Self);
00048
00049 typedef typename Superclass::ValueType ValueType;
00050 typedef typename Superclass::ValuePointer ValuePointer;
00051 typedef vnl_vector<ValueType> NodeVectorType;
00052 typedef typename Superclass::InternalVectorType InternalVectorType;
00053 typedef typename Superclass::OutputVectorType OutputVectorType;
00054 typedef typename Superclass::LayerInterfaceType LayerInterfaceType;
00055 typedef CompletelyConnectedWeightSet<TMeasurementVector,TTargetVector>
00056 WeightSetType;
00057
00058 typedef typename Superclass::WeightSetInterfaceType WeightSetInterfaceType;
00059 typedef typename Superclass::InputFunctionInterfaceType InputFunctionInterfaceType;
00060 typedef typename Superclass::TransferFunctionInterfaceType TransferFunctionInterfaceType;
00061
00062
00063 #ifdef ITK_USE_REVIEW_STATISTICS
00064 typedef EuclideanDistanceMetric<InternalVectorType> DistanceMetricType;
00065 #else
00066 typedef EuclideanDistance<InternalVectorType> DistanceMetricType;
00067 #endif
00068 typedef typename DistanceMetricType::Pointer DistanceMetricPointer;
00069 typedef RadialBasisFunctionBase<ValueType> RBFType;
00070
00071
00072 itkGetConstReferenceMacro(RBF_Dim, unsigned int);
00073 void SetRBF_Dim(unsigned int size);
00074 virtual void SetNumberOfNodes(unsigned int numNodes);
00075 virtual ValueType GetInputValue(unsigned int i) const;
00076 void SetInputValue(unsigned int i, ValueType value);
00077
00078 virtual ValueType GetOutputValue(unsigned int) const;
00079 virtual void SetOutputValue(unsigned int, ValueType);
00080
00081 virtual ValueType * GetOutputVector();
00082 void SetOutputVector(TMeasurementVector value);
00083
00084 virtual void ForwardPropagate();
00085 virtual void ForwardPropagate(TMeasurementVector input);
00086
00087 virtual void BackwardPropagate();
00088 virtual void BackwardPropagate(TTargetVector itkNotUsed(errors)){};
00089
00090 virtual void SetOutputErrorValues(TTargetVector);
00091 virtual ValueType GetOutputErrorValue(unsigned int node_id) const;
00092
00093 virtual ValueType GetInputErrorValue(unsigned int node_id) const;
00094 virtual ValueType * GetInputErrorVector();
00095 virtual void SetInputErrorValue(ValueType, unsigned int node_id);
00096
00097
00098 InternalVectorType GetCenter(unsigned int i) const;
00099 void SetCenter(TMeasurementVector c,unsigned int i);
00100
00101 ValueType GetRadii(unsigned int i) const;
00102 void SetRadii(ValueType c,unsigned int i);
00103
00104 virtual ValueType Activation(ValueType);
00105 virtual ValueType DActivation(ValueType);
00106
00108 itkSetMacro( Bias, ValueType );
00109 itkGetConstReferenceMacro( Bias, ValueType );
00111
00112 void SetDistanceMetric(DistanceMetricType* f);
00113 itkGetObjectMacro( DistanceMetric, DistanceMetricType );
00114 itkGetConstObjectMacro( DistanceMetric, DistanceMetricType );
00115
00116 itkSetMacro(NumClasses,unsigned int);
00117 itkGetConstReferenceMacro(NumClasses,unsigned int);
00118
00119 void SetRBF(RBFType* f);
00120 itkGetObjectMacro(RBF, RBFType);
00121 itkGetConstObjectMacro(RBF, RBFType);
00122
00123 protected:
00124
00125 RBFLayer();
00126 virtual ~RBFLayer();
00127
00129 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00130
00131 private:
00132
00133 NodeVectorType m_NodeInputValues;
00134 NodeVectorType m_NodeOutputValues;
00135 NodeVectorType m_InputErrorValues;
00136 NodeVectorType m_OutputErrorValues;
00137
00138 typename DistanceMetricType::Pointer m_DistanceMetric;
00139
00140 std::vector<InternalVectorType> m_Centers;
00141 InternalVectorType m_Radii;
00142 unsigned int m_NumClasses;
00143 ValueType m_Bias;
00144 unsigned int m_RBF_Dim;
00145 typename RBFType::Pointer m_RBF;
00146 };
00147
00148 }
00149 }
00150
00151 #ifndef ITK_MANUAL_INSTANTIATION
00152 #include "itkRBFLayer.txx"
00153 #endif
00154
00155 #endif
00156