ProductKernel.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Product of kernel functions.
6  *
7  *
8  *
9  * \author T. Glasmachers, O.Krause
10  * \date 2012
11  *
12  *
13  * \par Copyright 1995-2017 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://shark-ml.org/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_MODELS_KERNELS_PRODUCTKERNEL_H
36 #define SHARK_MODELS_KERNELS_PRODUCTKERNEL_H
37 
38 
40 
41 namespace shark{
42 
43 
44 ///
45 /// \brief Product of kernel functions.
46 ///
47 /// \par
48 /// The product of any number of kernels is again a valid kernel.
49 /// This class supports a kernel of the form
50 /// \f$ k(x, x') = k_1(x, x') \cdot k_2(x, x') \cdot \dots \cdot k_n(x, x') \f$
51 /// for any number of base kernels. All kernels need to be defined
52 /// on the same input space.
53 ///
54 /// \par
55 /// Derivatives are currently not implemented. Only the plain
56 /// kernel value can be computed. Everyone is free to add this
57 /// functionality :)
58 ///
59 template<class InputType>
60 class ProductKernel : public AbstractKernelFunction<InputType>
61 {
62 private:
64 public:
69  /// \brief Default constructor.
71  // this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
72  // this->m_features |= base_type::HAS_SECOND_PARAMETER_DERIVATIVE;
73  // this->m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
74  // this->m_features |= base_type::HAS_SECOND_INPUT_DERIVATIVE;
75  this->m_features |= base_type::IS_NORMALIZED; // an "empty" product is a normalized kernel (k(x, x) = 1).
76  }
77 
78  /// \brief Constructor for a product of two kernels.
79  ProductKernel(SubKernel* k1, SubKernel* k2){
80  // this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
81  // this->m_features |= base_type::HAS_SECOND_PARAMETER_DERIVATIVE;
82  // this->m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
83  // this->m_features |= base_type::HAS_SECOND_INPUT_DERIVATIVE;
84  this->m_features |= base_type::IS_NORMALIZED; // an "empty" product is a normalized kernel (k(x, x) = 1).
85  addKernel(k1);
86  addKernel(k2);
87  }
88  ProductKernel(std::vector<SubKernel*> kernels){
89  // this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
90  // this->m_features |= base_type::HAS_SECOND_PARAMETER_DERIVATIVE;
91  // this->m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
92  // this->m_features |= base_type::HAS_SECOND_INPUT_DERIVATIVE;
93  this->m_features |= base_type::IS_NORMALIZED; // an "empty" product is a normalized kernel (k(x, x) = 1).
94  for(std::size_t i = 0; i != kernels.size(); ++i)
95  addKernel(kernels[i]);
96  }
97 
98  /// \brief From INameable: return the class name.
99  std::string name() const
100  { return "ProductKernel"; }
101 
102  /// \brief Add one more kernel to the expansion.
103  ///
104  /// \param k The pointer is expected to remain valid during the lifetime of the ProductKernel object.
105  ///
106  void addKernel(SubKernel* k){
107  SHARK_ASSERT(k != NULL);
108 
109  m_kernels.push_back(k);
111  if (! k->isNormalized()) this->m_features.reset(base_type::IS_NORMALIZED); // products of normalized kernels are normalized.
112  }
113 
114  RealVector parameterVector() const{
115  RealVector ret(m_numberOfParameters);
116  std::size_t pos = 0;
117  for(auto kernel: m_kernels){
118  auto const& params = kernel->parameterVector();
119  noalias(subrange(ret,pos, pos + params.size())) = params;
120  pos += params.size();
121  }
122  return ret;
123  }
124 
125  void setParameterVector(RealVector const& newParameters){
126  SIZE_CHECK(newParameters.size() == m_numberOfParameters);
127 
128  std::size_t pos = 0;
129  for(auto kernel: m_kernels){
130  std::size_t numParams = kernel->numberOfParameters();
131  kernel->setParameterVector(subrange(newParameters,pos, pos + numParams));
132  pos += numParams;
133  }
134  }
135 
136  std::size_t numberOfParameters() const{
137  return m_numberOfParameters;
138  }
139 
140  /// \brief evaluates the kernel function
141  ///
142  /// This function returns the product of all sub-kernels.
143  double eval(ConstInputReference x1, ConstInputReference x2) const{
144  double prod = 1.0;
145  for (std::size_t i=0; i<m_kernels.size(); i++)
146  prod *= m_kernels[i]->eval(x1, x2);
147  return prod;
148  }
149 
150  void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result) const{
151  std::size_t sizeX1 = batchSize(batchX1);
152  std::size_t sizeX2 = batchSize(batchX2);
153 
154  //evaluate first kernel to initialize the result
155  m_kernels[0]->eval(batchX1,batchX2,result);
156 
157  RealMatrix kernelResult(sizeX1,sizeX2);
158  for(std::size_t i = 1; i != m_kernels.size(); ++i){
159  m_kernels[i]->eval(batchX1,batchX2,kernelResult);
160  noalias(result) *= kernelResult;
161  }
162  }
163 
164  void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result, State& state) const{
165  eval(batchX1,batchX2,result);
166  }
167 
168  /// From ISerializable.
169  void read(InArchive& ar){
170  for(std::size_t i = 0;i != m_kernels.size(); ++i ){
171  ar >> *m_kernels[i];
172  }
173  ar >> m_numberOfParameters;
174  }
175 
176  /// From ISerializable.
177  void write(OutArchive& ar) const{
178  for(std::size_t i = 0;i != m_kernels.size(); ++i ){
179  ar << const_cast<AbstractKernelFunction<InputType> const&>(*m_kernels[i]);//prevent serialization warning
180  }
181  ar << m_numberOfParameters;
182  }
183 
184 protected:
185  std::vector<SubKernel*> m_kernels; ///< vector of sub-kernels
186  std::size_t m_numberOfParameters; ///< total number of parameters in the product (this is redundant information)
187 };
188 
189 
190 }
191 #endif