00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018 #ifndef __itkRBFNetwork_h
00019 #define __itkRBFNetwork_h
00020
00021
00022 #include "itkMultilayerNeuralNetworkBase.h"
00023 #include "itkBackPropagationLayer.h"
00024 #include "itkRBFLayer.h"
00025 #include "itkCompletelyConnectedWeightSet.h"
00026 #include "itkSigmoidTransferFunction.h"
00027 #include "itkLogSigmoidTransferFunction.h"
00028 #include "itkSymmetricSigmoidTransferFunction.h"
00029 #include "itkTanSigmoidTransferFunction.h"
00030 #include "itkHardLimitTransferFunction.h"
00031 #include "itkSignedHardLimitTransferFunction.h"
00032 #include "itkGaussianTransferFunction.h"
00033 #include "itkTanHTransferFunction.h"
00034 #include "itkIdentityTransferFunction.h"
00035 #include "itkSumInputFunction.h"
00036 #include "itkEuclideanDistance.h"
00037
00038 namespace itk
00039 {
00040 namespace Statistics
00041 {
00042 template<class TVector, class TOutput>
00043 class RBFNetwork : public MultilayerNeuralNetworkBase<TVector, TOutput>
00044 {
00045 public:
00046
00047 typedef RBFNetwork Self;
00048 typedef MultilayerNeuralNetworkBase<TVector, TOutput> Superclass;
00049 typedef SmartPointer<Self> Pointer;
00050 typedef SmartPointer<const Self> ConstPointer;
00051 typedef typename Superclass::ValueType ValueType;
00052 typedef Array<ValueType> ArrayType;
00053 typedef TransferFunctionBase<ValueType> TransferFunctionType;
00054 typedef RadialBasisFunctionBase<ValueType> RBFType;
00055 typedef InputFunctionBase<ValueType*, ValueType> InputFunctionType;
00056 typedef EuclideanDistance<ArrayType> DistanceMetricType;
00057
00058 typename InputFunctionType::Pointer InputFunction;
00059 typename DistanceMetricType::Pointer DistanceMetric;
00060
00061 typename TransferFunctionType::Pointer InputTransferFunction;
00062 typename RBFType::Pointer HiddenTransferFunction;
00063 typename TransferFunctionType::Pointer OutputTransferFunction;
00064
00065 typedef typename Superclass::NetworkOutputType NetworkOutputType;
00066
00067
00068 itkTypeMacro(RBFNetwork,
00069 MultilayerNeuralNetworkBase);
00070 itkNewMacro(Self) ;
00071
00072
00073
00074 void Initialize();
00075
00076 itkSetMacro(NumOfInputNodes,unsigned int);
00077 itkGetConstReferenceMacro(NumOfInputNodes,unsigned int);
00078
00079 itkSetMacro(NumOfHiddenNodes,unsigned int);
00080 itkGetConstReferenceMacro(NumOfHiddenNodes, unsigned int);
00081
00082 itkSetMacro(NumOfOutputNodes,unsigned int);
00083 itkGetConstReferenceMacro(NumOfOutputNodes, unsigned int);
00084
00085 itkSetMacro(HiddenLayerBias, ValueType);
00086 itkGetConstReferenceMacro(HiddenLayerBias, ValueType);
00087
00088 itkSetMacro(OutputLayerBias, ValueType);
00089 itkGetConstReferenceMacro(OutputLayerBias, ValueType);
00090
00091 itkSetMacro(Classes,unsigned int);
00092 itkGetConstReferenceMacro(Classes,unsigned int);
00093
00094
00095 virtual NetworkOutputType GenerateOutput(TVector samplevector);
00096
00097 void SetInputTransferFunction(TransferFunctionType* f);
00098 void SetDistanceMetric(DistanceMetricType* f);
00099 void SetHiddenTransferFunction(TransferFunctionType* f);
00100 void SetOutputTransferFunction(TransferFunctionType* f);
00101
00102 void SetInputFunction(InputFunctionType* f);
00103 void InitializeWeights();
00104
00105 void SetCenter(TVector c);
00106 void SetRadius(ValueType r);
00107
00108 protected:
00109
00110 RBFNetwork();
00111 ~RBFNetwork(){};
00112
00114 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00115
00116 private:
00117
00118 unsigned int m_NumOfInputNodes;
00119 unsigned int m_NumOfHiddenNodes;
00120 unsigned int m_NumOfOutputNodes;
00121 unsigned int m_Classes;
00122 ValueType m_HiddenLayerBias;
00123 ValueType m_OutputLayerBias;
00124 std::vector<TVector> m_Centers;
00125 std::vector<double> m_Radii;
00126 };
00127
00128 }
00129 }
00130
00131 #ifndef ITK_MANUAL_INSTANTIATION
00132 #include "itkRBFNetwork.txx"
00133 #endif
00134
00135 #endif
00136