//===========================================================================
/*!
 *
 *
 * \brief Kernel function that applies a model as a transformation before evaluating an inner kernel.
 *
 *
 *
 * \author T. Glasmachers
 * \date 2012
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <http://shark-ml.org/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================

#ifndef SHARK_MODELS_KERNELS_MODEL_KERNEL_H
#define SHARK_MODELS_KERNELS_MODEL_KERNEL_H


#include <shark/Models/Kernels/AbstractKernelFunction.h>
#include <shark/LinAlg/Base.h>
#include <shark/Models/AbstractModel.h>
#include <vector>
#include <boost/scoped_ptr.hpp>

namespace shark {

namespace detail{
template<class InputType, class IntermediateType>
class ModelKernelImpl : public AbstractKernelFunction<InputType>
{
private:
    typedef AbstractKernelFunction<InputType> base_type;
public:
    typedef typename base_type::BatchInputType BatchInputType;
    typedef typename base_type::ConstInputReference ConstInputReference;
    typedef typename base_type::ConstBatchInputReference ConstBatchInputReference;
    typedef AbstractKernelFunction<IntermediateType> Kernel;
    typedef AbstractModel<InputType,IntermediateType> Model;
private:
    struct InternalState: public State{
        boost::shared_ptr<State> kernelStateX1X2;
        boost::shared_ptr<State> kernelStateX2X1;
        boost::shared_ptr<State> modelStateX1;
        boost::shared_ptr<State> modelStateX2;
        typename Model::BatchOutputType intermediateX1;
        typename Model::BatchOutputType intermediateX2;
    };
public:

    ModelKernelImpl(Kernel* kernel, Model* model):mpe_kernel(kernel),mpe_model(model){
        if(kernel->hasFirstParameterDerivative()
        && kernel->hasFirstInputDerivative()
        && model->hasFirstParameterDerivative())
            this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
    }

    /// \brief From INameable: return the class name.
    std::string name() const
    { return "ModelKernel"; }

    std::size_t numberOfParameters()const{
        return mpe_kernel->numberOfParameters() + mpe_model->numberOfParameters();
    }
    RealVector parameterVector() const{
        return mpe_kernel->parameterVector() | mpe_model->parameterVector();
    }
    void setParameterVector(RealVector const& newParameters){
        SIZE_CHECK(newParameters.size() == numberOfParameters());
        std::size_t kParams = mpe_kernel->numberOfParameters();
        mpe_kernel->setParameterVector(subrange(newParameters,0,kParams));
        mpe_model->setParameterVector(subrange(newParameters,kParams,newParameters.size()));
    }

    boost::shared_ptr<State> createState()const{
        InternalState* s = new InternalState();
        boost::shared_ptr<State> sharedState(s);//create now to allow for destructor to be called in case of exception
        s->kernelStateX1X2 = mpe_kernel->createState();
        s->kernelStateX2X1 = mpe_kernel->createState();
        s->modelStateX1 = mpe_model->createState();
        s->modelStateX2 = mpe_model->createState();
        return sharedState;
    }

    double eval(ConstInputReference x1, ConstInputReference x2) const{
        return mpe_kernel->eval((*mpe_model)(x1),(*mpe_model)(x2));
    }

    void eval(ConstBatchInputReference x1, ConstBatchInputReference x2, RealMatrix& result, State& state) const{
        InternalState& s=state.toState<InternalState>();
        mpe_model->eval(x1,s.intermediateX1,*s.modelStateX1);
        mpe_model->eval(x2,s.intermediateX2,*s.modelStateX2);
        //evaluate the kernel for both argument orders: the parameter derivative
        //needs the states of k(f(x1),f(x2)) as well as k(f(x2),f(x1)); the second
        //call overwrites result with the matrix that is actually returned
        mpe_kernel->eval(s.intermediateX2,s.intermediateX1,result,*s.kernelStateX2X1);
        mpe_kernel->eval(s.intermediateX1,s.intermediateX2,result,*s.kernelStateX1X2);
    }

    void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result) const{
        mpe_kernel->eval((*mpe_model)(batchX1),(*mpe_model)(batchX2),result);
    }

    void weightedParameterDerivative(
        ConstBatchInputReference batchX1,
        ConstBatchInputReference batchX2,
        RealMatrix const& coefficients,
        State const& state,
        RealVector& gradient
    ) const{
        gradient.resize(numberOfParameters());
        InternalState const& s=state.toState<InternalState>();

        //compute derivative of the kernel wrt. its parameters
        RealVector kernelGrad;
        mpe_kernel->weightedParameterDerivative(
            s.intermediateX1,s.intermediateX2,
            coefficients,*s.kernelStateX1X2,kernelGrad
        );
        //compute derivative of the kernel wrt. its left and right inputs
        typename Model::BatchOutputType inputDerivativeX1, inputDerivativeX2;
        mpe_kernel->weightedInputDerivative(
            s.intermediateX1,s.intermediateX2,
            coefficients,*s.kernelStateX1X2,inputDerivativeX1
        );
        mpe_kernel->weightedInputDerivative(
            s.intermediateX2,s.intermediateX1,
            trans(coefficients),*s.kernelStateX2X1,inputDerivativeX2
        );

        //compute derivative of the model wrt. its parameters, using the kernel's input derivatives as weights
        RealVector modelGradX1,modelGradX2;
        mpe_model->weightedParameterDerivative(batchX1,s.intermediateX1, inputDerivativeX1,*s.modelStateX1,modelGradX1);
        mpe_model->weightedParameterDerivative(batchX2,s.intermediateX2, inputDerivativeX2,*s.modelStateX2,modelGradX2);
        noalias(gradient) = kernelGrad | (modelGradX1+modelGradX2);
    }

    void read(InArchive& ar){
        SHARK_RUNTIME_CHECK(mpe_kernel, "The kernel function is NULL; the kernel needs to be constructed prior to reading it in");
        SHARK_RUNTIME_CHECK(mpe_model, "The model is NULL; the model needs to be constructed prior to reading it in");
        ar >> *mpe_kernel;
        ar >> *mpe_model;
    }

    void write(OutArchive& ar) const{
        ar << *mpe_kernel;
        ar << *mpe_model;
    }

private:
    Kernel* mpe_kernel;
    Model* mpe_model;
};
}


/// \brief Kernel function that uses a model as a transformation function for another kernel.
///
/// Using an AbstractModel \f$ f: X \rightarrow X' \f$ and an inner kernel
/// \f$k: X' \times X' \rightarrow \mathbb{R} \f$, this class defines another kernel
/// \f$K: X \times X \rightarrow \mathbb{R}\f$ using
/// \f[
/// K(x,y) = k(f(x),f(y))
///\f]
/// If the inner kernel \f$k\f$ supports both the input and the parameter derivative, and
/// the model also supports the parameter derivative, then the kernel \f$K\f$ also
/// supports the first parameter derivative using
/// \f[
/// \frac{\partial}{\partial \theta} K(x,y) =
/// \frac{\partial}{\partial f(x)} k(f(x),f(y))\frac{\partial}{\partial \theta} f(x)
/// +\frac{\partial}{\partial f(y)} k(f(x),f(y))\frac{\partial}{\partial \theta} f(y)
///\f]
/// This requires the derivative of the kernel with respect to both of its inputs which,
/// due to a limitation of the current kernel interface, requires computing both \f$k(f(x),f(y))\f$ and \f$k(f(y),f(x))\f$.
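///
/// \par Example
/// A minimal usage sketch. The concrete model and kernel types, the dimensions,
/// and the kernel bandwidth below are illustrative assumptions, not prescribed
/// by this header:
/// \code
/// #include <shark/Models/Kernels/ModelKernel.h>
/// #include <shark/Models/Kernels/GaussianRbfKernel.h>
/// #include <shark/Models/LinearModel.h>
/// using namespace shark;
///
/// LinearModel<RealVector> model(5, 2);       // f: maps R^5 to R^2
/// GaussianRbfKernel<RealVector> inner(0.5);  // k: Gaussian RBF on R^2, gamma = 0.5
/// ModelKernel<RealVector> kernel(&inner, &model);
///
/// RealVector x(5), y(5);                     // two inputs; fill with data
/// double kxy = kernel.eval(x, y);            // K(x,y) = k(f(x), f(y))
/// \endcode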
template<class InputType=RealVector>
class ModelKernel: public AbstractKernelFunction<InputType>{
private:
    typedef AbstractKernelFunction<InputType> base_type;
public:
    typedef typename base_type::BatchInputType BatchInputType;
    typedef typename base_type::ConstInputReference ConstInputReference;
    typedef typename base_type::ConstBatchInputReference ConstBatchInputReference;

    /// \brief Constructs the kernel from the inner kernel \f$k\f$ and the model \f$f\f$.
    template<class IntermediateType>
    ModelKernel(
        AbstractKernelFunction<IntermediateType>* kernel,
        AbstractModel<InputType,IntermediateType>* model
    ):m_wrapper(new detail::ModelKernelImpl<InputType,IntermediateType>(kernel,model)){
        SHARK_RUNTIME_CHECK(kernel, "The kernel function is not allowed to be NULL");
        SHARK_RUNTIME_CHECK(model, "The model is not allowed to be NULL");
        if(m_wrapper->hasFirstParameterDerivative())
            this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
    }

    /// \brief From INameable: return the class name.
    std::string name() const
    { return "ModelKernel"; }

    /// \brief Returns the number of parameters.
    std::size_t numberOfParameters()const{
        return m_wrapper->numberOfParameters();
    }
    ///\brief Returns the concatenated parameters of kernel and model.
    RealVector parameterVector() const{
        return m_wrapper->parameterVector();
    }
    ///\brief Sets the concatenated parameters of kernel and model.
    void setParameterVector(RealVector const& newParameters){
        m_wrapper->setParameterVector(newParameters);
    }

    ///\brief Returns the internal state object used for eval and the derivatives.
    boost::shared_ptr<State> createState()const{
        return m_wrapper->createState();
    }

    ///\brief Computes K(x,y) for a single input pair.
    double eval(ConstInputReference x1, ConstInputReference x2) const{
        return m_wrapper->eval(x1,x2);
    }

    /// \brief For two batches X1 and X2 computes the matrix k_ij=K(X1_i,X2_j) and stores the state for the derivatives.
    void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result, State& state) const{
        m_wrapper->eval(batchX1,batchX2,result,state);
    }
    /// \brief For two batches X1 and X2 computes the matrix k_ij=K(X1_i,X2_j).
    void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result) const{
        m_wrapper->eval(batchX1,batchX2,result);
    }

    ///\brief After a call to eval with state, computes the derivative with respect to all parameters of the kernel and the model.
    ///
    /// The derivative is computed over the whole kernel matrix k_ij created by eval and summed up using the coefficients c;
    /// thus this call returns \f$ \sum_{i,j} c_{ij} \frac{\partial}{\partial \theta} k(x^1_i,x^2_j)\f$.
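    ///
    /// A minimal calling-sequence sketch; it assumes a ModelKernel<RealVector> named
    /// kernel over 5-dimensional inputs as in the class-level example, and the batch
    /// sizes and all-ones coefficients are illustrative choices:
    /// \code
    /// RealMatrix B1(10, 5), B2(20, 5);   // input batches; fill with data
    /// RealMatrix K(10, 20);              // receives the kernel matrix k_ij
    /// boost::shared_ptr<State> state = kernel.createState();
    /// kernel.eval(B1, B2, K, *state);    // forward pass records the state
    /// RealMatrix c(10, 20, 1.0);         // coefficients c_ij, here all ones
    /// RealVector grad;                   // resized internally to numberOfParameters()
    /// kernel.weightedParameterDerivative(B1, B2, c, *state, grad);
    /// \endcode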
    void weightedParameterDerivative(
        ConstBatchInputReference batchX1,
        ConstBatchInputReference batchX2,
        RealMatrix const& coefficients,
        State const& state,
        RealVector& gradient
    ) const{
        m_wrapper->weightedParameterDerivative(batchX1,batchX2,coefficients,state,gradient);
    }

    ///\brief Stores the kernel to an Archive.
    void write(OutArchive& ar) const{
        ar << *m_wrapper;
    }
    ///\brief Reads the kernel from an Archive.
    void read(InArchive& ar){
        ar >> *m_wrapper;
    }
private:
    boost::scoped_ptr<AbstractKernelFunction<InputType> > m_wrapper;
};



}
#endif