00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkRBFLayerBase_h
00018 #define __itkRBFLayerBase_h
00019
00020 #include "itkLayerBase.h"
00021 #include "itkObject.h"
00022 #include "itkMacro.h"
00023 #include "itkRadialBasisFunctionBase.h"
00024 #include "itkEuclideanDistance.h"
00025
00026 namespace itk
00027 {
00028 namespace Statistics
00029 {
00030
00031 template<class TVector, class TOutput>
00032 class RBFLayer : public LayerBase<TVector, TOutput>
00033 {
00034 public:
00035
00036 typedef RBFLayer Self;
00037 typedef LayerBase<TVector, TOutput> Superclass;
00038 typedef SmartPointer<Self> Pointer;
00039 typedef SmartPointer<const Self> ConstPointer;
00040
00042 itkTypeMacro(RBFLayer, LayerBase);
00043 itkNewMacro(Self) ;
00045
00046 typedef typename Superclass::ValueType ValueType;
00047 typedef typename Superclass::ValuePointer ValuePointer;
00048 typedef vnl_vector<ValueType> NodeVectorType;
00049
00050 typedef typename Superclass::InternalVectorType InternalVectorType;
00051
00052 typedef typename Superclass::OutputVectorType OutputVectorType;
00053
00054 typedef RadialBasisFunctionBase<ValueType> RBFType;
00055
00056
00057 typedef EuclideanDistance<InternalVectorType> DistanceMetricType;
00058 typedef typename DistanceMetricType::Pointer DistanceMetricPointer;
00059
00060 void SetNumberOfNodes(unsigned int numNodes);
00061
00062 itkGetConstReferenceMacro(RBF_Dim, unsigned int);
00063 void SetRBF_Dim(unsigned int size);
00064
00065
00066 ValueType GetInputValue(unsigned int i) const;
00067 void SetInputValue(unsigned int i,ValueType value);
00068
00069 itkGetConstReferenceMacro(LayerType, unsigned int);
00070
00071 ValueType GetOutputValue(unsigned int) const;
00072 void SetOutputValue(unsigned int, ValueType);
00073
00074 ValuePointer GetOutputVector();
00075 void SetOutputVector(TVector value);
00076
00077 void ForwardPropagate();
00078 void ForwardPropagate(TVector input);
00079
00080 void BackwardPropagate();
00081 void BackwardPropagate(TOutput itkNotUsed(errors)){};
00082
00083 void SetOutputErrorValues(TOutput);
00084 ValueType GetOutputErrorValue(unsigned int node_id) const;
00085
00086
00087 ValueType GetInputErrorValue(unsigned int node_id) const;
00088 ValuePointer GetInputErrorVector();
00089 void SetInputErrorValue(ValueType, unsigned int node_id);
00090
00091
00092 InternalVectorType GetCenter(unsigned int i) const;
00093 void SetCenter(TVector c,unsigned int i);
00094
00095 ValueType GetRadii(unsigned int i) const;
00096 void SetRadii(ValueType c,unsigned int i);
00097
00098
00099 ValueType Activation(ValueType);
00100 ValueType DActivation(ValueType);
00101
00102 itkSetMacro( Bias, ValueType );
00103 itkGetConstReferenceMacro( Bias, ValueType );
00104
00105 void SetDistanceMetric(DistanceMetricType* f);
00106 itkGetObjectMacro( DistanceMetric, DistanceMetricType );
00107 itkGetConstObjectMacro( DistanceMetric, DistanceMetricType );
00108
00109 itkSetMacro(NumClasses,unsigned int);
00110 itkGetConstReferenceMacro(NumClasses,unsigned int);
00111
00112 void SetRBF(RBFType* f);
00113 itkGetObjectMacro(RBF, RBFType);
00114 itkGetConstObjectMacro(RBF, RBFType);
00115
00116 protected:
00117
00118 RBFLayer();
00119 ~RBFLayer();
00120
00122 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00123
00124 private:
00125
00126 typename DistanceMetricType::Pointer m_DistanceMetric;
00127 NodeVectorType m_NodeInputValues;
00128 NodeVectorType m_NodeOutputValues;
00129 NodeVectorType m_InputErrorValues;
00130 NodeVectorType m_OutputErrorValues;
00131
00132 std::vector<InternalVectorType> m_Centers;
00133 InternalVectorType m_Radii;
00134 unsigned int m_NumClasses;
00135 ValueType m_Bias;
00136 unsigned int m_RBF_Dim;
00137 typename RBFType::Pointer m_RBF;
00138 };
00139
00140 }
00141 }
00142
00143 #ifndef ITK_MANUAL_INSTANTIATION
00144 #include "itkRBFLayer.txx"
00145 #endif
00146
00147 #endif
00148