ITK  4.0.0
Insight Segmentation and Registration Toolkit
itkRBFNetwork.h
Go to the documentation of this file.
00001 /*=========================================================================
00002  *
00003  *  Copyright Insight Software Consortium
00004  *
00005  *  Licensed under the Apache License, Version 2.0 (the "License");
00006  *  you may not use this file except in compliance with the License.
00007  *  You may obtain a copy of the License at
00008  *
00009  *         http://www.apache.org/licenses/LICENSE-2.0.txt
00010  *
00011  *  Unless required by applicable law or agreed to in writing, software
00012  *  distributed under the License is distributed on an "AS IS" BASIS,
00013  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00014  *  See the License for the specific language governing permissions and
00015  *  limitations under the License.
00016  *
00017  *=========================================================================*/
00018 #ifndef __itkRBFNetwork_h
00019 #define __itkRBFNetwork_h
00020 
00021 #include "itkMultilayerNeuralNetworkBase.h"
00022 #include "itkBackPropagationLayer.h"
00023 #include "itkSigmoidTransferFunction.h"
00024 #include "itkLogSigmoidTransferFunction.h"
00025 #include "itkTanSigmoidTransferFunction.h"
00026 #include "itkHardLimitTransferFunction.h"
00027 #include "itkSignedHardLimitTransferFunction.h"
00028 #include "itkGaussianTransferFunction.h"
00029 #include "itkIdentityTransferFunction.h"
00030 #include "itkSumInputFunction.h"
00031 
00032 #include "itkSymmetricSigmoidTransferFunction.h"
00033 #include "itkTanHTransferFunction.h"
00034 #include "itkRBFLayer.h"
00035 
00036 namespace itk
00037 {
00038 namespace Statistics
00039 {
/** RBFNetwork: a neural-network class template made of one input layer, one
 * hidden layer and one output layer (see the note on Initialize() below).
 *
 * Template parameters:
 *   TMeasurementVector - type of the sample vector accepted by
 *                        GenerateOutput() and SetCenter(), and stored in
 *                        m_Centers.
 *   TTargetVector      - forwarded unchanged to the base class and to RBFLayer.
 *
 * The class derives from MultilayerNeuralNetworkBase, instantiated with
 * BackPropagationLayer as its third template argument. Member functions that
 * are declared here without a body are defined outside this class definition.
 */
00046 template<class TMeasurementVector, class TTargetVector>
00047 class RBFNetwork :
00048     public MultilayerNeuralNetworkBase<TMeasurementVector, TTargetVector, BackPropagationLayer<TMeasurementVector, TTargetVector> >
00049 {
00050 public:
        // Self / base-class / smart-pointer aliases used by the ITK macros below.
00051   typedef RBFNetwork               Self;
00052   typedef MultilayerNeuralNetworkBase<TMeasurementVector, TTargetVector , BackPropagationLayer<TMeasurementVector, TTargetVector> >
00053                                    Superclass;
00054   typedef SmartPointer<Self>       Pointer;
00055   typedef SmartPointer<const Self> ConstPointer;
00056 
        // Types re-exported from the base class so they can be named
        // unqualified inside this class template.
00057   typedef typename Superclass::ValueType             ValueType;
00058   typedef typename Superclass::MeasurementVectorType MeasurementVectorType;
00059   typedef typename Superclass::TargetVectorType      TargetVectorType;
00060   typedef typename Superclass::NetworkOutputType     NetworkOutputType;
00061 
00062   typedef typename Superclass::LayerInterfaceType LayerInterfaceType;
00063   typedef typename Superclass::LearningLayerType  LearningLayerType;
00064 
00065   typedef typename Superclass::WeightVectorType WeightVectorType;
00066   typedef typename Superclass::LayerVectorType  LayerVectorType;
00067 
00068   typedef typename Superclass::TransferFunctionInterfaceType TransferFunctionInterfaceType;
00069   typedef typename Superclass::InputFunctionInterfaceType    InputFunctionInterfaceType;
00070 
00071   // Specializations for RBF Networks
        // ArrayType is the element container handed to the Euclidean distance
        // metric; HiddenLayerType is the RBF-specific layer type.
00072   typedef Array<ValueType>                            ArrayType;
00073   typedef EuclideanDistanceMetric<ArrayType>          DistanceMetricType;
00074   typedef RadialBasisFunctionBase<ValueType>          RBFTransferFunctionType;
00075   typedef RBFLayer<TMeasurementVector, TTargetVector> HiddenLayerType;
00076 
        // Each itkSetMacro / itkGetConstReferenceMacro pair in this class
        // follows the ITK macro convention: it provides SetX() / GetX() for
        // the private member m_X (here: m_Classes).
00077   itkSetMacro(Classes, unsigned int);
00078   itkGetConstReferenceMacro(Classes, unsigned int);
        // RBF configuration hooks; their bodies are not in this header.
        // NOTE(review): presumably SetCenter()/SetRadius() append to
        // m_Centers/m_Radii (both are std::vector members) - confirm against
        // the implementation. Both take their argument by value.
00079   void SetCenter(TMeasurementVector c);
00080   void SetRadius(ValueType r);
00081   void SetDistanceMetric(DistanceMetricType* f);
00082   void InitializeWeights();
00083 
        // Run-time type information: passes this class name and its
        // superclass name to ITK's itkTypeMacro.
00085   itkTypeMacro(RBFNetwork,
00086                MultilayerNeuralNetworkBase);
        // itkNewMacro supplies the New() creation method (ITK convention).
00087   itkNewMacro(Self);
00089 
00090   //Add the layers to the network.
00091   // 1 input, 1 hidden, 1 output
00092   void Initialize();
00093 
        // Node counts for the input, (first) hidden and output layers.
00094   itkSetMacro(NumOfInputNodes, unsigned int);
00095   itkGetConstReferenceMacro(NumOfInputNodes, unsigned int);
00096 
00097   itkSetMacro(NumOfFirstHiddenNodes, unsigned int);
00098   itkGetConstReferenceMacro(NumOfFirstHiddenNodes, unsigned int);
00099 
00100   itkSetMacro(NumOfOutputNodes, unsigned int);
00101   itkGetConstReferenceMacro(NumOfOutputNodes, unsigned int);
00102 
        // Bias value associated with the first hidden layer.
00103   itkSetMacro(FirstHiddenLayerBias, ValueType);
00104   itkGetConstReferenceMacro(FirstHiddenLayerBias, ValueType);
00105 
00106   //#define __USE_OLD_INTERFACE  Comment out to ensure that new interface works
        // Backward-compatibility shims: compiled only when __USE_OLD_INTERFACE
        // is defined. Each one simply forwards to the renamed "FirstHidden"
        // accessor declared above.
00107 #ifdef __USE_OLD_INTERFACE
00108   //Original Function name before consistency naming changes
00109   inline void SetNumOfHiddenNodes(const unsigned int & x) { SetNumOfFirstHiddenNodes(x); }
00110   inline unsigned int GetNumOfHiddenNodes(void) const { return GetNumOfFirstHiddenNodes(); }
00111   inline void SetHiddenLayerBias(const ValueType & bias) { SetFirstHiddenLayerBias(bias); }
00112   ValueType GetHiddenLayerBias(void) const { return GetFirstHiddenLayerBias();}
00113 #endif
        // Bias value associated with the output layer.
00114   itkSetMacro(OutputLayerBias, ValueType);
00115   itkGetConstReferenceMacro(OutputLayerBias, ValueType);
00116 
        // Returns the network output for one sample vector (passed by value).
        // Declared virtual, so derived classes may override it.
00117   virtual NetworkOutputType GenerateOutput(TMeasurementVector samplevector);
00118 
        // Setters for the input function and the per-layer transfer functions.
        // All take raw pointers; the matching members are smart pointers.
00119   void SetInputFunction(InputFunctionInterfaceType* f);
00120   void SetInputTransferFunction(TransferFunctionInterfaceType* f);
00121 #ifdef __USE_OLD_INTERFACE
00122   //Original Function name before consistency naming changes
00123   inline void SetHiddenTransferFunction(TransferFunctionInterfaceType* f) { SetFirstHiddenTransferFunction (f); }
00124 #endif
        // NOTE(review): this accepts the generic TransferFunctionInterfaceType*,
        // but m_FirstHiddenTransferFunction is an RBFTransferFunctionType
        // pointer - the implementation presumably converts; confirm it rejects
        // or handles non-RBF transfer functions.
00125   void SetFirstHiddenTransferFunction(TransferFunctionInterfaceType* f);
00126   void SetOutputTransferFunction(TransferFunctionInterfaceType* f);
00127 protected:
00128 
        // Constructor and destructor are protected, so client code cannot
        // construct or destroy instances directly; New() (from itkNewMacro
        // above) is the creation path.
00129   RBFNetwork();
00130   virtual ~RBFNetwork(){};
00131 
        // Writes a description of this object to os using the given indent.
00133   virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00134 
00135 private:
00136 
        // RBF-specific state. NOTE(review): m_Radii holds double while
        // SetRadius() takes ValueType - confirm the conversion is intended.
00137   typename DistanceMetricType::Pointer            m_DistanceMetric;
00138   std::vector<TMeasurementVector>                 m_Centers;  // ui....uc
00139   std::vector<double>                             m_Radii;
00140 
        // Backing fields for the Set/Get macro pairs in the public section.
00141   unsigned int m_Classes;
00142   unsigned int m_NumOfInputNodes;
00143   unsigned int m_NumOfFirstHiddenNodes;
00144   unsigned int m_NumOfOutputNodes;
00145 
00146   ValueType m_FirstHiddenLayerBias;
00147   ValueType m_OutputLayerBias;
00148 
        // Function objects installed through the Set*Function() methods.
        // Only the first-hidden-layer slot is typed as an RBF transfer function.
00149   typename InputFunctionInterfaceType::Pointer    m_InputFunction;
00150   typename TransferFunctionInterfaceType::Pointer m_InputTransferFunction;
00151   typename RBFTransferFunctionType::Pointer       m_FirstHiddenTransferFunction;
00152   typename TransferFunctionInterfaceType::Pointer m_OutputTransferFunction;
00153 };
00154 
00155 } // end namespace Statistics
00156 } // end namespace itk
00157 
00158 #ifndef ITK_MANUAL_INSTANTIATION
00159 #include "itkRBFNetwork.hxx"
00160 #endif
00161 
00162 #endif
00163