35 #ifndef SHARK_MODELS_KERNELS_WEIGHTED_SUM_KERNEL_H 36 #define SHARK_MODELS_KERNELS_WEIGHTED_SUM_KERNEL_H 41 #include <boost/utility/enable_if.hpp> 62 template<
class InputType=RealVector>
// Internal evaluation cache shared between eval() and the derivative
// functions: one result matrix and one opaque State per sub-kernel.
// NOTE(review): this chunk is garbled/incomplete — a `result` member
// (resized below and read as s.result in weightedParameterDerivative)
// is declared on a line not visible here; confirm against the full file.
68 struct InternalState:
public State{
// Per-sub-kernel kernel matrices, filled by the stateful batch eval().
70 std::vector<RealMatrix> kernelResults;
// Opaque per-sub-kernel states, created via each kernel's createState().
71 std::vector<boost::shared_ptr<State> > kernelStates;
// Pre-sizes both caches to the number of sub-kernels.
73 InternalState(std::size_t numSubKernels)
74 :kernelResults(numSubKernels),kernelStates(numSubKernels){}
// Resizes the summed result and every per-kernel result matrix to the
// current batch sizes (rows = first batch, columns = second batch).
76 void resize(std::size_t sizeX1, std::size_t sizeX2){
77 result.resize(sizeX1, sizeX2);
78 for(std::size_t i = 0; i != kernelResults.size(); ++i){
79 kernelResults[i].resize(sizeX1, sizeX2);
// Constructor fragment: validates the sub-kernel list, stores the kernel
// pointers (non-adaptive by default), then probes every sub-kernel for
// derivative support so the combined kernel only advertises a capability
// when ALL sub-kernels provide it.
// NOTE(review): several statements (weight initialisation, flag
// declarations, loop/brace closers) fall in gaps of this garbled chunk.
89 SHARK_RUNTIME_CHECK( base.size() > 0,
"[WeightedSumKernel::WeightedSumKernel] There should be at least one sub-kernel.");
91 m_base.resize( base.size() );
// Adopt each supplied sub-kernel; adaptivity (whether a sub-kernel's own
// parameters appear in this kernel's parameter vector) starts out false.
94 for (std::size_t i=0; i !=
m_base.size() ; i++) {
96 m_base[i].kernel = base[i];
98 m_base[i].adaptive =
false;
// A single sub-kernel lacking parameter derivatives disables them for the sum.
104 for (
unsigned int i=0; i<
m_base.size(); i++ ){
105 if ( !
m_base[i].kernel->hasFirstParameterDerivative() ) {
106 hasFirstParameterDerivative =
false;
// Same all-or-nothing rule for input derivatives.
111 for (
unsigned int i=0; i<
m_base.size(); i++ ){
112 if ( !
m_base[i].kernel->hasFirstInputDerivative() ) {
113 hasFirstInputDerivative =
false;
// Advertise the derivative capabilities computed above (the flag-setting
// statements themselves are on lines missing from this chunk).
118 if ( hasFirstParameterDerivative )
121 if ( hasFirstInputDerivative )
// Fragments of several one-line accessors; the signatures sit on lines
// missing from this garbled chunk. From the visible bodies:
//  - name(): identifier used for serialization/diagnostics.
127 {
return "WeightedSumKernel"; }
//  - isAdaptive(index): whether sub-kernel `index` exposes its parameters.
131 return m_base[index].adaptive;
//  - setAdaptive(index, b): toggle adaptivity of one sub-kernel.
135 m_base[index].adaptive = b;
//  - setAdaptiveAll(b): presumably applies the same flag to every
//    sub-kernel in this loop — loop body not visible, confirm upstream.
140 for (std::size_t i=0; i!=
m_base.size(); i++)
//  - weight(index): current mixing weight of sub-kernel `index`.
148 return m_base[index].weight;
// parameterVector() fragment: flattens this kernel's parameters.
// Layout (mirrors setParameterVector below): first m_base.size()-1 entries
// encode the weights of sub-kernels 1..N-1 (the first weight is fixed),
// then the parameter vectors of the sub-kernels are stacked behind them.
161 std::size_t index = 0;
// Weight entries; the assignment inside this loop is on a missing line
// (setParameterVector applies std::exp, so these are log-encoded — confirm).
162 for (; index !=
m_base.size()-1; index++){
// Append each sub-kernel's own parameters (guard condition for adaptive
// kernels and the index advance sit on lines missing from this chunk).
166 for (std::size_t i=0; i !=
m_base.size(); i++){
168 std::size_t n =
m_base[i].kernel->numberOfParameters();
169 subrange(ret,index,index+n) =
m_base[i].kernel->parameterVector();
// createState() fragment: builds the composite InternalState, delegating
// to every sub-kernel for its own opaque state object. Ownership of the
// raw allocation is handed to the returned shared_ptr.
178 InternalState* state =
new InternalState(
m_base.size());
179 for(std::size_t i = 0; i !=
m_base.size(); ++i){
180 state->kernelStates[i]=
m_base[i].kernel->createState();
182 return boost::shared_ptr<State>(state);
// setParameterVector() fragment: the first m_base.size()-1 entries are
// log-encoded weights — exp() keeps every weight strictly positive and
// the weight of sub-kernel 0 stays fixed (only indices 1..N-1 are set).
// m_weightsum is accumulated for the normalisation used in eval();
// its re-initialisation is on a line missing from this chunk.
191 std::size_t index = 0;
192 for (; index !=
m_base.size()-1; index++){
193 double w = newParameters(index);
194 m_base[index+1].weight = std::exp(w);
195 m_weightsum +=
m_base[index+1].weight;
// Forward the remaining entries to the sub-kernels (the adaptive-only
// guard and index advance sit on lines missing from this chunk).
198 for (std::size_t i=0; i !=
m_base.size(); i++){
200 std::size_t n =
m_base[i].kernel->numberOfParameters();
201 m_base[i].kernel->setParameterVector(subrange(newParameters,index,index+n));
// Single-input evaluation: weighted sum of the sub-kernel values.
// The final normalisation (presumably numerator / m_weightsum, matching
// the batch overloads' use of m_weightsum) is on a line missing here.
213 double eval(ConstInputReference x1, ConstInputReference x2)
const{
214 double numerator = 0.0;
215 for (std::size_t i=0; i !=
m_base.size(); i++){
216 double result =
m_base[i].kernel->eval(x1, x2);
217 numerator +=
m_base[i].weight*result;
// Stateless batch evaluation: accumulates weight_i * K_i(batchX1,batchX2)
// into `result`, reusing one scratch matrix for every sub-kernel.
// NOTE(review): the sizeX1/sizeX2 definitions, the clearing of `result`
// before accumulation, and the final normalisation are on missing lines.
225 void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result)
const{
228 ensure_size(result,sizeX1,sizeX2);
231 RealMatrix kernelResult(sizeX1,sizeX2);
232 for (std::size_t i = 0; i !=
m_base.size(); i++){
233 m_base[i].kernel->eval(batchX1, batchX2,kernelResult);
234 result +=
m_base[i].weight*kernelResult;
// Stateful batch evaluation: like the stateless overload, but each
// sub-kernel's result matrix and state are cached in InternalState so the
// weighted*Derivative functions can reuse them without re-evaluating.
243 void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result,
State& state)
const{
246 ensure_size(result,sizeX1,sizeX2);
// Downcast the opaque State to our cache type and size it for this batch.
249 InternalState& s = state.
toState<InternalState>();
250 s.resize(sizeX1,sizeX2);
252 for (std::size_t i=0; i !=
m_base.size(); i++){
253 m_base[i].kernel->eval(batchX1,batchX2,s.kernelResults[i],*s.kernelStates[i]);
254 result +=
m_base[i].weight*s.kernelResults[i];
// weightedParameterDerivative() fragment: gradient of
// sum_ab coefficients(a,b) * k(x1_a, x2_b) w.r.t. this kernel's parameters,
// using the matrices cached in InternalState by the stateful eval().
262 ConstBatchInputReference batchX1,
263 ConstBatchInputReference batchX2,
264 RealMatrix
const& coefficients,
270 std::size_t numKernels =
m_base.size();
271 InternalState
// (typo in the stray residue above is extraction garbage, not code intent)
272 InternalState
// weightedInputDerivative() fragment: thin dispatcher that forwards to the
// SFINAE-selected weightedInputDerivativeImpl<BatchInputType> below —
// the RealMatrix overload does the work, any other batch type throws.
307 ConstBatchInputReference batchX1,
308 ConstBatchInputReference batchX2,
309 RealMatrix
const& coefficientsX2,
311 BatchInputType& gradient
315 weightedInputDerivativeImpl<BatchInputType>(batchX1,batchX2,coefficientsX2,state,gradient);
// Serialization fragments (read/write): each sub-kernel is (de)serialized
// in place through its pointer. Surrounding >>/<< lines for the weight
// and adaptive members are on lines missing from this chunk.
319 for(std::size_t i = 0;i !=
m_base.size(); ++i ){
// Deserialize the sub-kernel object itself.
322 ar >> *(
m_base[i].kernel);
328 for(std::size_t i=0;i!=
m_base.size();++i){
// const_cast to a const reference so the archive takes the kernel through
// its const serialization path when writing.
331 ar << const_cast<AbstractKernelFunction<InputType>
const&>(*(
m_base[i].kernel));
// Loop header fragment over all sub-kernels; body not visible in this
// chunk — presumably part of recomputing the parameter count from the
// adaptive flags (confirm against the full file).
348 for (std::size_t i=0; i !=
m_base.size(); i++)
// weightedInputDerivativeImpl, enabled only when T == RealMatrix: the
// input derivative of the weighted sum is the weighted sum of the
// sub-kernels' input derivatives. Kernel 0 writes directly into
// `gradient`; the remaining kernels accumulate via a scratch matrix.
359 ConstBatchInputReference batchX1,
360 ConstBatchInputReference batchX2,
361 RealMatrix
const& coefficientsX2,
363 BatchInputType& gradient,
// SFINAE guard: this overload participates only for RealMatrix batches.
364 typename boost::enable_if<boost::is_same<T,RealMatrix > >::type* dummy = 0
366 std::size_t numKernels =
m_base.size();
367 InternalState
const& s = state.
toState<InternalState>();
// First sub-kernel initialises the gradient (its scaling by the weight
// is on a line missing from this chunk — confirm upstream).
371 m_base[0].kernel->weightedInputDerivative(batchX1, batchX2, coefficientsX2, *s.kernelStates[0], gradient);
373 BatchInputType kernelGrad;
374 for (std::size_t i=1; i != numKernels; i++){
375 m_base[i].kernel->weightedInputDerivative(batchX1, batchX2, coefficientsX2, *s.kernelStates[i], kernelGrad);
// `coeff` is defined on a missing line — presumably weight_i / m_weightsum.
377 gradient += coeff * kernelGrad;
// SFINAE fallback of weightedInputDerivativeImpl for any batch type that
// is NOT RealMatrix: the input derivative is undefined there, so this
// overload unconditionally throws.
// NOTE(review): the exception message misspells "weightedInputDerivative"
// as "weightdInputDerivative" — worth fixing upstream (runtime string,
// intentionally left byte-identical here).
382 ConstBatchInputReference batchX1,
383 ConstBatchInputReference batchX2,
384 RealMatrix
const& coefficientsX2,
386 BatchInputType& gradient,
// disable_if mirrors the enable_if guard of the RealMatrix overload.
387 typename boost::disable_if<boost::is_same<T,RealMatrix > >::type* dummy = 0
389 throw SHARKEXCEPTION(
"[WeightedSumKernel::weightdInputDerivative] The used BatchInputType is no Vector");