RBFLayer.h
/*!
 *
 * \brief Implements a radial basis function layer.
 *
 * \author O. Krause
 * \date 2014
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <http://shark-ml.org/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
#ifndef SHARK_MODELS_RBFLayer_H
#define SHARK_MODELS_RBFLayer_H

#include <shark/Core/DLLSupport.h>
#include <shark/Models/AbstractModel.h>
#include <boost/math/constants/constants.hpp>

namespace shark {

/// \brief Implements a layer of radial basis functions in a neural network.
///
/// A radial basis function layer as modeled in Shark is a set of N
/// Gaussian distributions \f$ p(x|i) \f$,
/// \f[
///   p(x|i) = e^{-\gamma_i \|x-m_i\|^2},
/// \f]
/// and the layer transforms an input x into the vector \f$(p(x|1),\dots,p(x|N))\f$.
/// The \f$\gamma_i\f$ govern the width of the Gaussians, while the
/// vectors \f$ m_i \f$ set the center of every Gaussian distribution.
///
/// RBF networks profit greatly from good guesses of the centers and
/// kernel function parameters. In the case of a Gaussian kernel, a call
/// to k-Means or the EM-algorithm can be used to obtain a good
/// initialisation for the network.
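///
/// A minimal usage sketch (the sizes, center values and width below are
/// hypothetical placeholders; only members declared in this class and the
/// batch evaluation operator inherited from AbstractModel are used):
/// \code
/// RBFLayer layer(2, 3);                // 2 inputs, 3 Gaussian neurons
/// RealMatrix c(3, 2);                  // one row per center, e.g. taken from k-Means
/// // ... fill c with the desired center positions ...
/// layer.centers() = c;
/// layer.setGamma(RealVector(3, 0.5));  // width parameter of every Gaussian
///
/// RealMatrix batch(10, 2);             // a batch of 10 input points
/// // ... fill batch with data ...
/// RealMatrix responses = layer(batch); // row i holds (p(x_i|1),...,p(x_i|3))
/// \endcode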
class RBFLayer : public AbstractModel<RealVector,RealVector>
{
private:
    struct InternalState: public State{
        RealMatrix norm2;

        void resize(std::size_t numPatterns, std::size_t numNeurons){
            norm2.resize(numPatterns, numNeurons);
        }
    };

public:
    /// \brief Creates an empty Radial Basis Function layer.
    SHARK_EXPORT_SYMBOL RBFLayer();

    /// \brief Creates a layer of a Radial Basis Function Network.
    ///
    /// This constructor creates a Radial Basis Function Network (RBFN) layer with
    /// \em numInput input neurons and \em numOutput output neurons.
    ///
    /// \param numInput Number of input neurons, equal to the dimensionality of
    /// the input space.
    /// \param numOutput Number of output neurons, equal to the dimensionality of
    /// the output space and the number of Gaussian distributions.
    SHARK_EXPORT_SYMBOL RBFLayer(std::size_t numInput, std::size_t numOutput);

    /// \brief From INameable: return the class name.
    std::string name() const
    { return "RBFLayer"; }

    ///\brief Returns the current parameter vector. The number and order of parameters depend on which parameters are enabled for training.
    ///
    ///The format of the parameter vector is \f$ (m_1,\dots,m_k,\log(\gamma_1),\dots,\log(\gamma_k))\f$.
    ///If training of the centers or the widths is deactivated, the corresponding entries are removed from the parameter vector.
    SHARK_EXPORT_SYMBOL RealVector parameterVector()const;
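    // Worked example (hypothetical sizes): for RBFLayer(2, 3) with both centers
    // and widths enabled for training, the parameter vector holds the 3*2 = 6
    // center coordinates followed by the 3 log-width values, i.e. 9 entries;
    // numberOfParameters() reports this count.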

    ///\brief Sets the new internal parameters.
    SHARK_EXPORT_SYMBOL void setParameterVector(RealVector const& newParameters);

    ///\brief Returns the number of parameters which are currently enabled for training.
    SHARK_EXPORT_SYMBOL std::size_t numberOfParameters()const;

    ///\brief Returns the number of input neurons.
    Shape inputShape()const{
        return m_centers.size2();
    }

    ///\brief Returns the number of output neurons.
    Shape outputShape()const{
        return m_centers.size1();
    }

    boost::shared_ptr<State> createState()const{
        return boost::shared_ptr<State>(new InternalState());
    }

    /// \brief Configures a Radial Basis Function Network.
    ///
    /// This method initializes the structure of the Radial Basis Function Network (RBFN)
    /// with \em numInput input neurons and \em numOutput output neurons.
    ///
    /// \param numInput Number of input neurons, equal to the dimensionality of
    /// the input space.
    /// \param numOutput Number of output neurons (basis functions), equal to the dimensionality of
    /// the output space.
    SHARK_EXPORT_SYMBOL void setStructure(std::size_t numInput, std::size_t numOutput);
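    // Sketch: setStructure reinitialises an existing layer (the sizes below are
    // arbitrary examples):
    //   RBFLayer layer;            // empty layer
    //   layer.setStructure(4, 10); // 4 inputs, 10 Gaussian neurons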

    using AbstractModel<RealVector,RealVector>::eval;
    SHARK_EXPORT_SYMBOL void eval(BatchInputType const& patterns, BatchOutputType& outputs, State& state)const;

    SHARK_EXPORT_SYMBOL void weightedParameterDerivative(
        BatchInputType const& pattern, BatchOutputType const& outputs,
        BatchOutputType const& coefficients, State const& state, RealVector& gradient
    )const;
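    // A minimal sketch of the usual eval/derivative call sequence ("layer", the
    // batch and the coefficients are hypothetical; coefficients must have the
    // same shape as outputs):
    //   boost::shared_ptr<State> state = layer.createState();
    //   RealMatrix outputs;
    //   layer.eval(batch, outputs, *state);
    //   RealVector gradient;
    //   layer.weightedParameterDerivative(batch, outputs, coefficients, *state, gradient);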

    ///\brief Enables or disables parameters for learning.
    ///
    /// \param centers whether the centers should be trained
    /// \param width whether the distribution width should be trained
    SHARK_EXPORT_SYMBOL void setTrainingParameters(bool centers, bool width);
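    // Sketch: train only the centers and keep the widths fixed;
    // numberOfParameters() and parameterVector() shrink accordingly:
    //   layer.setTrainingParameters(true, false);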

    ///\brief Returns the center values of the neurons.
    BatchInputType const& centers()const{
        return m_centers;
    }
    ///\brief Sets the center values of the neurons.
    BatchInputType& centers(){
        return m_centers;
    }

    ///\brief Returns the width parameters of the Gaussian functions.
    RealVector const& gamma()const{
        return m_gamma;
    }

    /// \brief Sets the width parameters - the gamma values - of the distributions.
    SHARK_EXPORT_SYMBOL void setGamma(RealVector const& gamma);
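    // Sketch: given the density e^{-gamma_i*||x-m_i||^2} above, a Gaussian with
    // standard deviation sigma corresponds to gamma = 1/(2*sigma^2), e.g.
    //   double sigma = 0.3; // hypothetical bandwidth
    //   layer.setGamma(RealVector(layer.centers().size1(), 1.0/(2.0*sigma*sigma)));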

    /// From ISerializable, reads a model from an archive
    SHARK_EXPORT_SYMBOL void read( InArchive & archive );

    /// From ISerializable, writes a model to an archive
    SHARK_EXPORT_SYMBOL void write( OutArchive & archive ) const;
protected:
    //====model parameters

    ///\brief The center points. The i-th row corresponds to the center of neuron number i.
    RealMatrix m_centers;

    ///\brief Stores the width parameters of the Gaussian functions.
    RealVector m_gamma;

    /// \brief The logarithm of the normalization constant for every distribution.
    RealVector m_logNormalization;

    //=====training parameters

    ///enables learning of the center points of the neurons
    bool m_trainCenters;
    ///enables learning of the width parameters.
    bool m_trainWidth;
};
}

#endif