35 #ifndef SHARK_MODELS_KERNELS_MONOMIAL_KERNEL_H 36 #define SHARK_MODELS_KERNELS_MONOMIAL_KERNEL_H 48 template<
class InputType=RealVector>
54 struct InternalState:
public State{
56 RealMatrix exponentedProd;
58 void resize(std::size_t sizeX1, std::size_t sizeX2){
59 base.resize(sizeX1, sizeX2);
60 exponentedProd.resize(sizeX1, sizeX2);
81 {
return "MonomialKernel"; }
95 return boost::shared_ptr<State>(
new InternalState());
99 double eval(ConstInputReference x1, ConstInputReference x2)
const{
101 double prod=inner_prod(x1, x2);
105 void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result)
const{
106 SIZE_CHECK(batchX1.size2() == batchX2.size2());
107 std::size_t sizeX1 = batchX1.size1();
108 std::size_t sizeX2 = batchX2.size1();
109 result.resize(sizeX1,sizeX2);
110 noalias(result) = prod(batchX1,trans(batchX2));
115 void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result,
State& state)
const{
116 SIZE_CHECK(batchX1.size2() == batchX2.size2());
117 std::size_t sizeX1 = batchX1.size1();
118 std::size_t sizeX2 = batchX2.size1();
119 result.resize(sizeX1,sizeX2);
122 InternalState& s = state.
toState<InternalState>();
123 s.resize(sizeX1,sizeX2);
126 noalias(s.base) = prod(batchX1,trans(batchX2));
131 noalias(result) = s.base;
133 noalias(s.exponentedProd) = result;
140 ConstBatchInputReference batchX1,
141 ConstBatchInputReference batchX2,
142 RealMatrix
const& coefficients,
146 SIZE_CHECK(batchX1.size2() == batchX2.size2());
151 ConstBatchInputReference batchX1,
152 ConstBatchInputReference batchX2,
153 RealMatrix
const& coefficientsX2,
155 BatchInputType& gradient
158 std::size_t sizeX1 = batchX1.size1();
159 std::size_t sizeX2 = batchX2.size1();
160 gradient.resize(sizeX1,batchX1.size2());
161 InternalState
const& s = state.
toState<InternalState>();
164 SIZE_CHECK(batchX1.size2() == batchX2.size2());
167 SIZE_CHECK(s.exponentedProd.size1() == sizeX1);
168 SIZE_CHECK(s.exponentedProd.size2() == sizeX2);
172 RealMatrix weights = coefficientsX2 * safe_div(s.exponentedProd,s.base,0.0);
176 noalias(gradient) =
m_exponent * prod(weights,batchX2);