AbstractKernelFunction.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief abstract super class of all kernel functions
6  *
7  *
8  *
9  * \author T.Glasmachers, O. Krause, M. Tuma
10  * \date 2010-2012
11  *
12  *
13  * \par Copyright 1995-2017 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://shark-ml.org/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_MODELS_KERNELS_ABSTRACTKERNELFUNCTION_H
36 #define SHARK_MODELS_KERNELS_ABSTRACTKERNELFUNCTION_H
37 
39 #include <shark/LinAlg/Base.h>
40 #include <shark/Core/Flags.h>
41 #include <shark/Core/State.h>
42 namespace shark {
43 
44 #ifdef SHARK_COUNT_KERNEL_LOOKUPS
45  #define INCREMENT_KERNEL_COUNTER( counter ) { counter++; }
46 #else
47  #define INCREMENT_KERNEL_COUNTER( counter ) { }
48 #endif
49 
50 /// \brief Base class of all Kernel functions.
51 ///
52 /// \par
53 /// A (Mercer) kernel is a symmetric positive definite
54 /// function of two parameters. It is (currently) used
55 /// in two contexts in Shark, namely for kernel methods
56 /// such as support vector machines (SVMs), and for
57 /// radial basis function networks.
58 ///
59 /// \par
60 /// In Shark a kernel function class represents a parametric
61 /// family of such kernel functions: The AbstractKernelFunction
62 /// interface inherits the IParameterizable interface.
63 ///
64 template<class InputTypeT>
65 class AbstractKernelFunction : public AbstractMetric<InputTypeT>
66 {
67 private:
69  typedef Batch<InputTypeT> Traits;
70 public:
71  /// \brief Input type of the Kernel.
72  typedef typename base_type::InputType InputType;
73  /// \brief batch input type of the kernel
75  /// \brief Const references to InputType
77  /// \brief Const references to BatchInputType
79 
81 
82  /// enumerations of kerneland metric features (flags)
83  enum Feature {
84  HAS_FIRST_PARAMETER_DERIVATIVE = 1, ///< is the kernel differentiable w.r.t. its parameters?
85  HAS_FIRST_INPUT_DERIVATIVE = 2, ///< is the kernel differentiable w.r.t. its inputs?
86  IS_NORMALIZED = 4 , ///< does k(x, x) = 1 hold for all inputs x?
87  SUPPORTS_VARIABLE_INPUT_SIZE = 8 ///< Input arguments must have same size, but not the same size in different calls to eval
88  };
89 
90  /// This statement declares the member m_features. See Core/Flags.h for details.
92 
95  }
98  }
	/// \brief Returns true if the IS_NORMALIZED feature flag is set,
	/// i.e. k(x, x) = 1 holds for all inputs x.
	///
	/// NOTE(review): m_features is declared by a flags macro whose line is not
	/// visible in this extraction (see Core/Flags.h) -- presumably a TypedFlags
	/// member; confirm against the original header.
	bool isNormalized() const{
		return m_features & IS_NORMALIZED;
	}
104  }
105 
106  ///\brief Creates an internal state of the kernel.
107  ///
108  ///The state is needed when the derivatives are to be
109  ///calculated. Eval can store a state which is then reused to speed up
110  ///the calculations of the derivatives. This also allows eval to be
111  ///evaluated in parallel!
112  virtual boost::shared_ptr<State> createState()const
113  {
114  SHARK_RUNTIME_CHECK(!hasFirstParameterDerivative() && !hasFirstInputDerivative(), "createState must be overridden by kernels with derivatives");
115  return boost::shared_ptr<State>(new EmptyState());
116  }
117 
118  ///////////////////////////////////////////SINGLE ELEMENT INTERFACE///////////////////////////////////////////
119  // By default, this is mapped to the batch case.
120 
121  /// \brief Evaluates the kernel function.
122  virtual double eval(ConstInputReference x1, ConstInputReference x2) const{
123  RealMatrix res;
124  BatchInputType b1 = Traits::createBatch(x1,1);
125  BatchInputType b2 = Traits::createBatch(x2,1);
126  getBatchElement(b1,0) = x1;
127  getBatchElement(b2,0) = x2;
128  eval(b1, b2, res);
129  return res(0, 0);
130  }
131 
132  /// \brief Convenience operator which evaluates the kernel function.
133  inline double operator () (ConstInputReference x1, ConstInputReference x2) const {
134  return eval(x1, x2);
135  }
136 
137  //////////////////////////////////////BATCH INTERFACE///////////////////////////////////////////
138 
	/// \brief Evaluates the subset of the KernelGram matrix which is defined by X1(rows) and X2 (columns).
	///
	/// The result matrix is filled in with the values result(i,j) = kernel(x1[i], x2[j]);
	/// The State object is filled in with data used in subsequent derivative computations.
	/// This is the one member every concrete kernel must implement.
	virtual void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result, State& state) const = 0;
144 
145  /// \brief Evaluates the subset of the KernelGram matrix which is defined by X1(rows) and X2 (columns).
146  ///
147  /// The result matrix is filled in with the values result(i,j) = kernel(x1[i], x2[j]);
148  virtual void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result) const {
149  boost::shared_ptr<State> state = createState();
150  eval(batchX1, batchX2, result, *state);
151  }
152 
153  /// \brief Evaluates the subset of the KernelGram matrix which is defined by X1(rows) and X2 (columns).
154  ///
155  /// Convenience operator.
156  /// The result matrix is filled in with the values result(i,j) = kernel(x1[i], x2[j]);
157  inline RealMatrix operator () (ConstBatchInputReference batchX1, ConstBatchInputReference batchX2) const{
158  RealMatrix result;
159  eval(batchX1, batchX2, result);
160  return result;
161  }
162 
	/// \brief Computes the gradient of the parameters as a weighted sum over the gradient of all elements of the batch.
	///
	/// The default implementation throws a "not implemented" exception.
	///
	/// NOTE(review): this extraction is incomplete -- the function header
	/// (doxygen listing line 166, presumably
	/// "virtual void weightedParameterDerivative(") and the body statement
	/// (listing line 173, presumably the "not implemented" exception throw)
	/// are missing here; restore both from the original header.
		ConstBatchInputReference batchX1,
		ConstBatchInputReference batchX2,
		RealMatrix const& coefficients,
		State const& state,
		RealVector& gradient
	) const {
	}
175 
	/// \brief Calculates the derivative of the inputs X1 (only x1!).
	///
	/// The i-th row of the resulting matrix is a weighted sum of the form:
	/// c[i,0] * k'(x1[i], x2[0]) + c[i,1] * k'(x1[i], x2[1]) + ... + c[i,n] * k'(x1[i], x2[n]).
	///
	/// The default implementation throws a "not implemented" exception.
	virtual void weightedInputDerivative(
		ConstBatchInputReference batchX1,
		ConstBatchInputReference batchX2,
		RealMatrix const& coefficientsX2,
		State const& state,
		BatchInputType& gradient
	) const {
		// NOTE(review): the body statement (doxygen listing line 189,
		// presumably the "not implemented" exception throw) is missing from
		// this extraction; restore it from the original header.
	}
191 
192 
193  //////////////////////////////////NORMS AND DISTANCES/////////////////////////////////
194 
195  /// Computes the squared distance in the kernel induced feature space.
196  virtual double featureDistanceSqr(ConstInputReference x1, ConstInputReference x2) const{
197  if (isNormalized()){
198  double k12 = eval(x1, x2);
199  return (2.0 - 2.0 * k12);
200  } else {
201  double k11 = eval(x1, x1);
202  double k12 = eval(x1, x2);
203  double k22 = eval(x2, x2);
204  return (k11 - 2.0 * k12 + k22);
205  }
206  }
207 
208  virtual RealMatrix featureDistanceSqr(ConstBatchInputReference batchX1,ConstBatchInputReference batchX2) const{
209  std::size_t sizeX1 = batchSize(batchX1);
210  std::size_t sizeX2 = batchSize(batchX2);
211  RealMatrix result=(*this)(batchX1,batchX2);
212  result *= -2.0;
213  if (isNormalized()){
214  noalias(result) += 2.0;
215  } else {
216  //compute self-product
217  RealVector kx2(sizeX2);
218  for(std::size_t i = 0; i != sizeX2;++i){
219  kx2(i)=eval(getBatchElement(batchX2,i),getBatchElement(batchX2,i));
220  }
221  for(std::size_t j = 0; j != sizeX1;++j){
222  double kx1=eval(getBatchElement(batchX1,j),getBatchElement(batchX1,j));
223  noalias(row(result,j)) += kx1 + kx2;
224  }
225  }
226  return result;
227  }
228 };
229 
230 
231 }
232 #endif