ITK
4.1.0
Insight Segmentation and Registration Toolkit
|
00001 /*========================================================================= 00002 * 00003 * Copyright Insight Software Consortium 00004 * 00005 * Licensed under the Apache License, Version 2.0 (the "License"); 00006 * you may not use this file except in compliance with the License. 00007 * You may obtain a copy of the License at 00008 * 00009 * http://www.apache.org/licenses/LICENSE-2.0.txt 00010 * 00011 * Unless required by applicable law or agreed to in writing, software 00012 * distributed under the License is distributed on an "AS IS" BASIS, 00013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00014 * See the License for the specific language governing permissions and 00015 * limitations under the License. 00016 * 00017 *=========================================================================*/ 00018 #ifndef __itkRBFNetwork_h 00019 #define __itkRBFNetwork_h 00020 00021 #include "itkMultilayerNeuralNetworkBase.h" 00022 #include "itkBackPropagationLayer.h" 00023 #include "itkSigmoidTransferFunction.h" 00024 #include "itkLogSigmoidTransferFunction.h" 00025 #include "itkTanSigmoidTransferFunction.h" 00026 #include "itkHardLimitTransferFunction.h" 00027 #include "itkSignedHardLimitTransferFunction.h" 00028 #include "itkGaussianTransferFunction.h" 00029 #include "itkIdentityTransferFunction.h" 00030 #include "itkSumInputFunction.h" 00031 00032 #include "itkSymmetricSigmoidTransferFunction.h" 00033 #include "itkTanHTransferFunction.h" 00034 #include "itkRBFLayer.h" 00035 00036 namespace itk 00037 { 00038 namespace Statistics 00039 { 00046 template<class TMeasurementVector, class TTargetVector> 00047 class RBFNetwork : 00048 public MultilayerNeuralNetworkBase<TMeasurementVector, TTargetVector, BackPropagationLayer<TMeasurementVector, TTargetVector> > 00049 { 00050 public: 00051 typedef RBFNetwork Self; 00052 typedef MultilayerNeuralNetworkBase<TMeasurementVector, TTargetVector , BackPropagationLayer<TMeasurementVector, TTargetVector> > 00053 Superclass; 00054 typedef SmartPointer<Self> Pointer; 00055 typedef SmartPointer<const Self> ConstPointer; 00056 00057 typedef typename Superclass::ValueType ValueType; 00058 typedef typename Superclass::MeasurementVectorType MeasurementVectorType; 00059 typedef typename Superclass::TargetVectorType TargetVectorType; 00060 typedef typename Superclass::NetworkOutputType NetworkOutputType; 00061 00062 typedef typename Superclass::LayerInterfaceType LayerInterfaceType; 00063 typedef typename Superclass::LearningLayerType LearningLayerType; 00064 00065 typedef typename Superclass::WeightVectorType WeightVectorType; 00066 typedef typename Superclass::LayerVectorType LayerVectorType; 00067 00068 typedef typename Superclass::TransferFunctionInterfaceType TransferFunctionInterfaceType; 00069 typedef typename Superclass::InputFunctionInterfaceType InputFunctionInterfaceType; 00070 00071 // Specializations for RBF Networks 00072 typedef Array<ValueType> ArrayType; 00073 typedef EuclideanDistanceMetric<ArrayType> DistanceMetricType; 00074 typedef RadialBasisFunctionBase<ValueType> RBFTransferFunctionType; 00075 typedef RBFLayer<TMeasurementVector, TTargetVector> HiddenLayerType; 00076 00077 itkSetMacro(Classes, unsigned int); 00078 itkGetConstReferenceMacro(Classes, unsigned int); 00079 void SetCenter(TMeasurementVector c); 00080 void SetRadius(ValueType r); 00081 void SetDistanceMetric(DistanceMetricType* f); 00082 void InitializeWeights(); 00083 00085 itkTypeMacro(RBFNetwork, 00086 MultilayerNeuralNetworkBase); 00087 itkNewMacro(Self); 00089 00090 //Add the layers to the network. 00091 // 1 input, 1 hidden, 1 output 00092 void Initialize(); 00093 00094 itkSetMacro(NumOfInputNodes, unsigned int); 00095 itkGetConstReferenceMacro(NumOfInputNodes, unsigned int); 00096 00097 itkSetMacro(NumOfFirstHiddenNodes, unsigned int); 00098 itkGetConstReferenceMacro(NumOfFirstHiddenNodes, unsigned int); 00099 00100 itkSetMacro(NumOfOutputNodes, unsigned int); 00101 itkGetConstReferenceMacro(NumOfOutputNodes, unsigned int); 00102 00103 itkSetMacro(FirstHiddenLayerBias, ValueType); 00104 itkGetConstReferenceMacro(FirstHiddenLayerBias, ValueType); 00105 00106 //#define __USE_OLD_INTERFACE Comment out to ensure that new interface works 00107 #ifdef __USE_OLD_INTERFACE 00108 //Original Function name before consistency naming changes 00109 inline void SetNumOfHiddenNodes(const unsigned int & x) { SetNumOfFirstHiddenNodes(x); } 00110 inline unsigned int GetNumOfHiddenNodes(void) const { return GetNumOfFirstHiddenNodes(); } 00111 inline void SetHiddenLayerBias(const ValueType & bias) { SetFirstHiddenLayerBias(bias); } 00112 ValueType GetHiddenLayerBias(void) const { return GetFirstHiddenLayerBias();} 00113 #endif 00114 itkSetMacro(OutputLayerBias, ValueType); 00115 itkGetConstReferenceMacro(OutputLayerBias, ValueType); 00116 00117 virtual NetworkOutputType GenerateOutput(TMeasurementVector samplevector); 00118 00119 void SetInputFunction(InputFunctionInterfaceType* f); 00120 void SetInputTransferFunction(TransferFunctionInterfaceType* f); 00121 #ifdef __USE_OLD_INTERFACE 00122 //Original Function name before consistency naming changes 00123 inline void SetHiddenTransferFunction(TransferFunctionInterfaceType* f) { SetFirstHiddenTransferFunction (f); } 00124 #endif 00125 void SetFirstHiddenTransferFunction(TransferFunctionInterfaceType* f); 00126 void SetOutputTransferFunction(TransferFunctionInterfaceType* f); 00127 protected: 00128 00129 RBFNetwork(); 00130 virtual ~RBFNetwork(){}; 00131 00133 virtual void PrintSelf( std::ostream& os, Indent indent ) const; 00134 00135 private: 00136 00137 typename DistanceMetricType::Pointer m_DistanceMetric; 00138 std::vector<TMeasurementVector> m_Centers; // ui....uc 00139 std::vector<double> m_Radii; 00140 00141 unsigned int m_Classes; 00142 unsigned int m_NumOfInputNodes; 00143 unsigned int m_NumOfFirstHiddenNodes; 00144 unsigned int m_NumOfOutputNodes; 00145 00146 ValueType m_FirstHiddenLayerBias; 00147 ValueType m_OutputLayerBias; 00148 00149 typename InputFunctionInterfaceType::Pointer m_InputFunction; 00150 typename TransferFunctionInterfaceType::Pointer m_InputTransferFunction; 00151 typename RBFTransferFunctionType::Pointer m_FirstHiddenTransferFunction; 00152 typename TransferFunctionInterfaceType::Pointer m_OutputTransferFunction; 00153 }; 00154 00155 } // end namespace Statistics 00156 } // end namespace itk 00157 00158 #ifndef ITK_MANUAL_INSTANTIATION 00159 #include "itkRBFNetwork.hxx" 00160 #endif 00161 00162 #endif 00163