// SubrangeKernelWrapper<InputType>: adapts a kernel so that it only sees a
// contiguous slice of each input. Every evaluation restricts the inputs to the
// index range starting at m_start with m_end - m_start entries (see the temp
// buffer in the input derivative below) and forwards the call to the wrapped
// kernel m_kernel.
// NOTE(review): this excerpt is missing several lines of the original class
// (declarations of m_start/m_end, some method headers and closing braces);
// the comments below describe only what is visible.
35 #ifndef SHARK_MODELS_KERNELS_SUBRANGE_KERNEL_H 36 #define SHARK_MODELS_KERNELS_SUBRANGE_KERNEL_H 44 template<
class InputType>
45 class SubrangeKernelWrapper :
public AbstractKernelFunction<InputType>{
47 typedef AbstractKernelFunction<InputType> base_type;
// Constructor: stores the wrapped kernel pointer and the slice bounds.
// kernel is dereferenced immediately, so it must not be null.
// The two capability checks query the wrapped kernel's derivative support; the
// statements they guard are not visible here (presumably they enable the
// matching feature flags on this wrapper - confirm).
53 SubrangeKernelWrapper(AbstractKernelFunction<InputType>* kernel,std::size_t start, std::size_t end)
54 :m_kernel(kernel),m_start(start),m_end(end){
55 if(kernel->hasFirstParameterDerivative())
57 if(kernel->hasFirstInputDerivative())
// Returns the type name of this wrapper.
62 std::string
name()
const 63 {
return "SubrangeKernelWrapper"; }
// Parameter access and state creation are forwarded unchanged to the wrapped
// kernel (method headers not visible); the wrapper contributes no parameters
// of its own, since numberOfParameters() is the wrapped kernel's count.
66 return m_kernel->parameterVector();
70 m_kernel->setParameterVector(newParameters);
74 return m_kernel->numberOfParameters();
79 return m_kernel->createState();
// Single-pair evaluation: the wrapped kernel applied to the sub-vectors
// blas::subrange(x, m_start, m_end) of both inputs.
82 double eval(ConstInputReference x1, ConstInputReference x2)
const{
83 return m_kernel->eval(blas::subrange(x1,m_start,m_end),blas::subrange(x2,m_start,m_end));
// Batch evaluation: both batches are restricted to the column range
// [m_start, m_end) (presumably one input per row - confirm) and the wrapped
// kernel fills result. This overload also passes the State object through.
86 void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result, State& state)
const{
87 m_kernel->eval(columns(batchX1,m_start,m_end),columns(batchX2,m_start,m_end),result,state);
// Batch evaluation without a State object; same column restriction as above.
90 void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result)
const{
91 m_kernel->eval(columns(batchX1,m_start,m_end),columns(batchX2,m_start,m_end),result);
// Fragment of the weighted parameter derivative (function header and trailing
// arguments not visible): the column-restricted batches are forwarded to
// m_kernel->weightedParameterDerivative, so the parameter gradient is exactly
// the wrapped kernel's.
95 ConstBatchInputReference batchX1,
96 ConstBatchInputReference batchX2,
97 RealMatrix
const& coefficients,
101 m_kernel->weightedParameterDerivative(
102 columns(batchX1,m_start,m_end),
103 columns(batchX2,m_start,m_end),
// Fragment of the weighted input derivative (function header not visible).
// The wrapped kernel only produces derivatives for the slice, so they are
// computed into temp (m_end - m_start columns) and then copied into columns
// [m_start, m_end) of gradient.
// NOTE(review): whether the remaining columns of gradient are zeroed is not
// visible in this excerpt - confirm, otherwise they keep stale values.
110 ConstBatchInputReference batchX1,
111 ConstBatchInputReference batchX2,
112 RealMatrix
const& coefficientsX2,
114 BatchInputType& gradient
// NOTE(review): temp's row count comes from gradient.size1() *before* gradient
// is resized by ensure_size below. If the caller passes an empty or
// differently sized gradient this is wrong, unless the wrapped kernel resizes
// temp itself; batchX1.size1() looks like the intended value.
116 BatchInputType temp(gradient.size1(),m_end-m_start);
117 m_kernel->weightedInputDerivative(
118 columns(batchX1,m_start,m_end),
119 columns(batchX2,m_start,m_end),
// Resize gradient to batchX1.size1() rows.
// NOTE(review): the column count is taken from batchX2.size2(). This equals the
// input dimension of batchX1 only if both batches share the same dimension
// (presumably always true - confirm); batchX1.size2() would state the intent.
124 ensure_size(gradient,batchX1.size1(),batchX2.size2());
126 noalias(columns(gradient,m_start,m_end)) = temp;
// The wrapped kernel. This is a raw pointer and no deletion is visible in this
// excerpt, so it is presumably non-owning: the caller must keep the kernel
// alive for as long as this wrapper is used.
136 AbstractKernelFunction<InputType>* m_kernel;
// detail::SubrangeKernelBase<InputType>: stores one SubrangeKernelWrapper (by
// value) per sub-kernel, so that a derived kernel can refer to them through
// pointers.
// NOTE(review): the opening/closing braces and access specifiers of this class
// are not visible in this excerpt.
141 template<
class InputType>
142 class SubrangeKernelBase
// Builds the wrappers: kernel i is restricted to the range
// [ranges[i].first, ranges[i].second).
// Kernels must support size() and operator[]; elements of Ranges must expose
// .first/.second. ranges is indexed for every i < kernels.size() with no size
// check, so it must hold at least as many entries as kernels.
146 template<
class Kernels,
class Ranges>
147 SubrangeKernelBase(Kernels
const& kernels, Ranges
const& ranges){
149 for(std::size_t i = 0; i != kernels.size(); ++i){
150 m_kernelWrappers.push_back(
151 SubrangeKernelWrapper<InputType>(kernels[i],ranges[i].first,ranges[i].second)
// Returns non-owning pointers to every stored wrapper, in construction order.
// The pointers point into m_kernelWrappers. They dangle if that vector
// reallocates or this object is destroyed.
// NOTE(review): if the owning object is copied with an implicit copy
// constructor, the copy's consumers would still point at the source object's
// wrappers - confirm copying is handled.
156 std::vector<AbstractKernelFunction<InputType>* > makeKernelVector(){
157 std::vector<AbstractKernelFunction<InputType>* > kernels(m_kernelWrappers.size());
158 for(std::size_t i = 0; i != m_kernelWrappers.size(); ++i)
159 kernels[i] = & m_kernelWrappers[i];
// One range-restricting wrapper per sub-kernel; owned by value.
163 std::vector<SubrangeKernelWrapper <InputType> > m_kernelWrappers;
// SubrangeKernel<InputType, InnerKernel>: applies each sub-kernel to its own
// index range of the input. It does this by building range-restricting
// wrappers (via detail::SubrangeKernelBase) and passing pointers to them to
// InnerKernel. InnerKernel defaults to WeightedSumKernel, which presumably
// combines the sub-kernels as a weighted sum.
// NOTE(review): the class-head line and the InnerKernel base-specifier are not
// visible in this excerpt.
188 template<
class InputType,
class InnerKernel=WeightedSumKernel<InputType> >
190 :
private detail::SubrangeKernelBase<InputType>
194 typedef detail::SubrangeKernelBase<InputType> base_type1;
// Returns the type name of this kernel (function header not visible here).
199 {
return "SubrangeKernel"; }
// Constructor (header line not visible). base_type1 first builds one wrapper
// per (kernel, range) pair; InnerKernel is then constructed from pointers to
// those wrappers.
// Base classes are initialized in declaration order, not in the order of this
// initializer list. This code therefore relies on SubrangeKernelBase being
// declared before InnerKernel in the base-specifier list, as the visible base
// list suggests - confirm.
201 template<
class Kernels,
class Ranges>
203 : base_type1(kernels,ranges)
204 , InnerKernel(base_type1::makeKernelVector()){}