ScaledKernel.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief A kernel function that wraps a member kernel and multiplies it by a scalar.
6  *
7  *
8  *
9  * \author M. Tuma, T. Glasmachers, O. Krause
10  * \date 2012
11  *
12  *
13  * \par Copyright 1995-2017 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://shark-ml.org/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_MODELS_KERNELS_SCALED_KERNEL_H
36 #define SHARK_MODELS_KERNELS_SCALED_KERNEL_H
37 
38 
40 namespace shark {
41 
42 
43 /// \brief Scaled version of a kernel function
44 ///
45 /// For a positive definite kernel k, the scaled kernel
46 /// \f[ \tilde k(x_1, x_2) := c k(x_1, x_2) \f]
47 /// is again a positive definite kernel function as long as \f$ c > 0 \f$.
48 template<class InputType=RealVector>
49 class ScaledKernel : public AbstractKernelFunction<InputType>
50 {
51 private:
53 public:
57 
59  : m_base( base ),
60  m_factor( factor )
61  {
62  RANGE_CHECK( factor > 0 );
63  SHARK_ASSERT( base != NULL );
68  }
69 
70  /// \brief From INameable: return the class name.
71  std::string name() const
72  { return "ScaledKernel"; }
73 
74  RealVector parameterVector() const {
75  return m_base->parameterVector();
76  }
77  void setParameterVector(RealVector const& newParameters) {
78  m_base->setParameterVector(newParameters);
79  }
80 
81  std::size_t numberOfParameters() const {
82  return m_base->numberOfParameters();
83  }
84 
85  ///\brief creates the internal state of the kernel
86  boost::shared_ptr<State> createState()const{
87  return m_base->createState();
88  }
89 
90  const double factor() {
91  return m_factor;
92  }
93  void setFactor( double f ) {
94  RANGE_CHECK( f > 0 );
95  m_factor = f;
96  }
97 
98  const base_type* base() const {
99  return m_base;
100  }
101 
102  double eval(ConstInputReference x1, ConstInputReference x2) const {
103  SIZE_CHECK(x1.size() == x2.size());
104  return m_factor * m_base->eval(x1, x2);
105  }
106 
107  void eval(ConstBatchInputReference x1, ConstBatchInputReference x2, RealMatrix& result) const{
108  m_base->eval(x1, x2,result);
109  result *= m_factor;
110  }
111 
112  void eval(ConstBatchInputReference x1, ConstBatchInputReference x2, RealMatrix& result, State& state) const{
113  m_base->eval(x1, x2,result,state);
114  result *= m_factor;
115  }
116 
/// Calculates the weighted derivative w.r.t. the parameters of the base kernel (the result is scaled by the factor).
119  ConstBatchInputReference batchX1,
120  ConstBatchInputReference batchX2,
121  RealMatrix const& coefficients,
122  State const& state,
123  RealVector& gradient
124  ) const{
125  m_base->weightedParameterDerivative( batchX1, batchX2, coefficients, state, gradient );
126  gradient *= m_factor;
127  }
/// Calculates the weighted derivative w.r.t. argument \f$ x_1 \f$ (the result is scaled by the factor).
130  ConstBatchInputReference batchX1,
131  ConstBatchInputReference batchX2,
132  RealMatrix const& coefficientsX2,
133  State const& state,
134  BatchInputType& gradient
135  ) const{
136  SIZE_CHECK(coefficientsX2.size1() == batchSize(batchX1));
137  SIZE_CHECK(coefficientsX2.size2() == batchSize(batchX2));
138  m_base->weightedInputDerivative( batchX1, batchX2, coefficientsX2, state, gradient );
139  gradient *= m_factor;
140  }
141 
142  void read(InArchive& ar){
143  ar >> m_factor;
144  ar >> *m_base;
145  }
146 
147  /// \brief The kernel does not serialize anything
148  void write(OutArchive& ar) const{
149  ar << m_factor;
150  //const cast needed to prevent warning
151  ar << const_cast<AbstractKernelFunction<InputType> const&>(*m_base);
152  }
153 
154 protected:
156  double m_factor; ///< scaling factor
157 };
158 
161 
162 }
163 #endif