00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkLayerBase_h
00018 #define __itkLayerBase_h
00019
00020 #include <iostream>
00021 #include "itkLightProcessObject.h"
00022 #include "itkWeightSetBase.h"
00023 #include "itkArray.h"
00024 #include "itkVector.h"
00025 #include "itkTransferFunctionBase.h"
00026 #include "itkInputFunctionBase.h"
00027
00028 #include "itkMacro.h"
00029
00030 namespace itk
00031 {
00032 namespace Statistics
00033 {
00034
00035 template<class TVector, class TOutput>
00036 class LayerBase : public LightProcessObject
00037 {
00038
00039 public:
00040 typedef LayerBase Self;
00041 typedef LightProcessObject Superclass;
00042 typedef SmartPointer<Self> Pointer;
00043 typedef SmartPointer<const Self> ConstPointer;
00044
00046 itkTypeMacro(LayerBase, LightProcessObject);
00047
00048 typedef TVector InputVectorType;
00049 typedef TOutput OutputVectorType;
00050
00051 typedef typename TVector::ValueType ValueType;
00052 typedef ValueType* ValuePointer;
00053 typedef const ValueType* ValueConstPointer;
00054 typedef vnl_vector<ValueType> NodeVectorType;
00055 typedef Array<ValueType> InternalVectorType;
00056
00057 typedef WeightSetBase<TVector,TOutput> WeightSetType;
00058
00059 typedef TransferFunctionBase<ValueType> TransferFunctionType;
00060
00061 typedef InputFunctionBase<ValueType*, ValueType> InputFunctionType;
00062
00063 typedef typename InputFunctionType::Pointer InputFunctionPointer;
00064 typedef typename InputFunctionType::ConstPointer InputFunctionConstPointer;
00065
00066 typedef typename TransferFunctionType::Pointer TransferFunctionPointer;
00067 typedef typename TransferFunctionType::ConstPointer TransferFunctionConstPointer;
00068
00069 typedef typename WeightSetType::Pointer WeightSetPointer;
00070 typedef typename WeightSetType::ConstPointer WeightSetConstPointer;
00071
00072 virtual void SetNumberOfNodes(unsigned int);
00073 unsigned int GetNumberOfNodes() const;
00074
00075 virtual ValueType GetInputValue(unsigned int) const = 0;
00076 virtual ValueType GetOutputValue(unsigned int) const = 0;
00077 virtual ValuePointer GetOutputVector() = 0;
00078
00079 virtual void ForwardPropagate(){};
00080
00081 virtual void ForwardPropagate(TVector){};
00082
00083 virtual void BackwardPropagate(InternalVectorType){};
00084
00085 virtual void BackwardPropagate(){};
00086 virtual ValueType GetOutputErrorValue(unsigned int) const = 0;
00087 virtual void SetOutputErrorValues(TOutput) {};
00088
00089 virtual ValueType GetInputErrorValue(unsigned int) const = 0;
00090 virtual ValuePointer GetInputErrorVector() = 0;
00091 virtual void SetInputErrorValue(ValueType, unsigned int) {};
00092
00093
00094 void SetInputWeightSet(WeightSetType*);
00095 itkGetObjectMacro(InputWeightSet, WeightSetType);
00096 itkGetConstObjectMacro(InputWeightSet, WeightSetType);
00097
00098
00099 void SetOutputWeightSet(WeightSetType*);
00100 itkGetObjectMacro(OutputWeightSet, WeightSetType);
00101 itkGetConstObjectMacro(OutputWeightSet, WeightSetType);
00102
00103 void SetNodeInputFunction(InputFunctionType* f);
00104 itkGetObjectMacro(NodeInputFunction, InputFunctionType);
00105 itkGetConstObjectMacro(NodeInputFunction, InputFunctionType);
00106
00107 void SetTransferFunction(TransferFunctionType* f);
00108 itkGetObjectMacro(ActivationFunction, TransferFunctionType);
00109 itkGetConstObjectMacro(ActivationFunction, TransferFunctionType);
00110
00111 virtual ValueType Activation(ValueType) = 0;
00112 virtual ValueType DActivation(ValueType) = 0;
00113
00114 itkSetMacro(LayerType, unsigned int);
00115 itkGetConstReferenceMacro(LayerType, unsigned int);
00116
00117 itkSetMacro(LayerId,unsigned int);
00118 itkGetConstReferenceMacro(LayerId,unsigned int);
00119
00120 virtual void SetBias(const ValueType) = 0;
00121 virtual const ValueType & GetBias() const = 0;
00122
00123
00124 protected:
00125
00126 LayerBase();
00127 ~LayerBase();
00128
00130 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00131
00132 unsigned int m_LayerType;
00133 unsigned int m_LayerId;
00134 unsigned int m_NumberOfNodes;
00135
00136 typename WeightSetType::Pointer m_InputWeightSet;
00137 typename WeightSetType::Pointer m_OutputWeightSet;
00138
00139 TransferFunctionPointer m_ActivationFunction;
00140 InputFunctionPointer m_NodeInputFunction;
00141
00142
00143
00144 };
00145
00146 }
00147 }
00148
00149 #ifndef ITK_MANUAL_INSTANTIATION
00150 #include "itkLayerBase.txx"
00151 #endif
00152
00153 #endif
00154