// ModelKernelImpl<InputType, IntermediateType>: a kernel on InputType that first maps both
// arguments through a Model (InputType -> IntermediateType) and then evaluates a Kernel
// defined on IntermediateType on the mapped values.
// NOTE(review): several lines of this class are not visible in this view (among them the
// declarations of mpe_kernel/mpe_model and of the ConstInputReference/ConstBatchInputReference
// typedefs); these comments only describe what is shown.
35 #ifndef SHARK_MODELS_KERNELS_MODEL_KERNEL_H 36 #define SHARK_MODELS_KERNELS_MODEL_KERNEL_H 43 #include <boost/scoped_ptr.hpp> 48 template<
class InputType,
class IntermediateType>
49 class ModelKernelImpl :
public AbstractKernelFunction<InputType>
52 typedef AbstractKernelFunction<InputType> base_type;
// Kernel evaluated in the intermediate space.
57 typedef AbstractKernelFunction<IntermediateType> Kernel;
// Model mapping InputType into the intermediate space.
58 typedef AbstractModel<InputType,IntermediateType> Model;
// Evaluation state filled by createState(): separate kernel states for the (x1,x2) and
// (x2,x1) argument orders and one model state per input batch.
// NOTE(review): the batched eval() also stores intermediateX1/intermediateX2 in this struct;
// those member declarations are not visible here.
60 struct InternalState:
public State{
61 boost::shared_ptr<State> kernelStateX1X2;
62 boost::shared_ptr<State> kernelStateX2X1;
63 boost::shared_ptr<State> modelStateX1;
64 boost::shared_ptr<State> modelStateX2;
// Stores the kernel and model pointers. Both must be non-null because they are dereferenced
// below. NOTE(review): ownership is not visible here (presumably non-owning), so confirm.
// The condition requires exactly the three derivatives that weightedParameterDerivative()
// chains together: the kernel's parameter and input derivatives and the model's parameter
// derivative.
70 ModelKernelImpl(Kernel* kernel, Model* model):mpe_kernel(kernel),mpe_model(model){
71 if(kernel->hasFirstParameterDerivative()
72 && kernel->hasFirstInputDerivative()
73 && model->hasFirstParameterDerivative())
// NOTE(review): the statement guarded by this `if` is not visible in this view.
// Presumably it marks this kernel as providing a first parameter derivative; confirm.
78 std::string name()
const 79 {
return "ModelKernel"; }
81 std::size_t numberOfParameters()
const{
82 return mpe_kernel->numberOfParameters() + mpe_model->numberOfParameters();
84 RealVector parameterVector()
const{
85 return mpe_kernel->parameterVector() | mpe_model->parameterVector();
87 void setParameterVector(RealVector
const& newParameters){
88 SIZE_CHECK(newParameters.size() == numberOfParameters());
89 std::size_t kParams =mpe_kernel->numberOfParameters();
90 mpe_kernel->setParameterVector(subrange(newParameters,0,kParams));
91 mpe_model->setParameterVector(subrange(newParameters,kParams,newParameters.size()));
94 boost::shared_ptr<State> createState()
const{
95 InternalState* s =
new InternalState();
96 boost::shared_ptr<State> sharedState(s);
97 s->kernelStateX1X2 = mpe_kernel->createState();
98 s->kernelStateX2X1 = mpe_kernel->createState();
99 s->modelStateX1 = mpe_model->createState();
100 s->modelStateX2 = mpe_model->createState();
104 double eval(ConstInputReference x1, ConstInputReference x2)
const{
105 return mpe_kernel->eval((*mpe_model)(x1),(*mpe_model)(x2));
108 void eval(ConstBatchInputReference x1, ConstBatchInputReference x2, RealMatrix& result, State& state)
const{
109 InternalState& s=state.toState<InternalState>();
110 mpe_model->eval(x1,s.intermediateX1,*s.modelStateX1);
111 mpe_model->eval(x2,s.intermediateX2,*s.modelStateX2);
112 mpe_kernel->eval(s.intermediateX2,s.intermediateX1,result,*s.kernelStateX2X1);
113 mpe_kernel->eval(s.intermediateX1,s.intermediateX2,result,*s.kernelStateX1X2);
117 void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result)
const{
118 return mpe_kernel->eval((*mpe_model)(batchX1),(*mpe_model)(batchX2),result);
// Gradient of the coefficient-weighted kernel with respect to all parameters. It is laid out
// as [kernel parameters | model parameters], matching parameterVector().
// Reads the intermediate values and states stored in `state`.
// NOTE(review): this assumes the batched eval(batchX1, batchX2, result, state) was called
// on the same batches beforehand; confirm with callers.
// NOTE(review): several lines of this function are not visible in this view: the tail of the
// parameter list, the declarations of inputDerivativeX1/X2, and the closing `)` of some calls.
121 void weightedParameterDerivative(
122 ConstBatchInputReference batchX1,
123 ConstBatchInputReference batchX2,
124 RealMatrix
const& coefficients,
128 gradient.resize(numberOfParameters());
129 InternalState
const& s=state.toState<InternalState>();
// Kernel part: derivative with respect to the kernel's own parameters at the intermediate points.
132 RealVector kernelGrad;
133 mpe_kernel->weightedParameterDerivative(
134 s.intermediateX1,s.intermediateX2,
135 coefficients,*s.kernelStateX1X2,kernelGrad
// Chain rule, step 1: derivative of the kernel with respect to its first argument (the x1 side) ...
139 mpe_kernel->weightedInputDerivative(
140 s.intermediateX1,s.intermediateX2,
141 coefficients,*s.kernelStateX1X2,inputDerivativeX1
// ... and with respect to the x2 side, obtained by swapping the arguments, transposing the
// weights and using the (x2,x1) kernel state.
143 mpe_kernel->weightedInputDerivative(
144 s.intermediateX2,s.intermediateX1,
145 trans(coefficients),*s.kernelStateX2X1,inputDerivativeX2
// Chain rule, step 2: back-propagate through the model for each batch. The same model maps
// both arguments, so the two model gradients are summed.
149 RealVector modelGradX1,modelGradX2;
150 mpe_model->weightedParameterDerivative(batchX1,s.intermediateX1, inputDerivativeX1,*s.modelStateX1,modelGradX1);
151 mpe_model->weightedParameterDerivative(batchX2,s.intermediateX2, inputDerivativeX2,*s.modelStateX2,modelGradX2);
152 noalias(gradient) = kernelGrad | (modelGradX1+modelGradX2);
// Fragment of a deserialization routine; its signature is not visible in this view
// (presumably read(); confirm). It refuses to proceed unless both the kernel and the model
// pointers are set, because the data is read into existing objects.
156 SHARK_RUNTIME_CHECK(mpe_kernel,
"The kernel function is NULL, kernel needs to be constructed prior to read in");
157 SHARK_RUNTIME_CHECK(mpe_model,
"The model is NULL, model needs to be constructed prior to read in");
// ModelKernel<InputType>: public kernel front-end (its class declaration line is not visible
// here). It forwards every call to a heap-allocated detail::ModelKernelImpl held in m_wrapper,
// so IntermediateType appears only as a template parameter of the constructor.
192 template<
class InputType=RealVector>
// Constructor template (its parameter list is not visible here). It wraps the given kernel
// and model in a ModelKernelImpl for this InputType/IntermediateType pair.
201 template<
class IntermediateType>
205 ):m_wrapper(new detail::ModelKernelImpl<
InputType,IntermediateType>(kernel,model)){
// NOTE(review): the statement guarded by this `if` is not visible. Presumably it mirrors the
// wrapper's derivative capability in this object's feature flags; confirm.
208 if(m_wrapper->hasFirstParameterDerivative())
// Body of name(); the signature is not visible.
214 {
return "ModelKernel"; }
// Forwarding bodies, presumably of numberOfParameters(), parameterVector(),
// setParameterVector() and createState(); their signatures are not visible in this view.
218 return m_wrapper->numberOfParameters();
222 return m_wrapper->parameterVector();
226 m_wrapper->setParameterVector(newParameters);
231 return m_wrapper->createState();
235 double eval(ConstInputReference x1, ConstInputReference x2)
const{
236 return m_wrapper->eval(x1,x2);
240 void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result,
State& state)
const{
241 return m_wrapper->eval(batchX1,batchX2,result,state);
244 void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result)
const{
245 m_wrapper->eval(batchX1,batchX2,result);
// Fragment of weightedParameterDerivative(); the beginning and end of its signature are not
// visible in this view. It forwards unchanged to the wrapped implementation.
253 ConstBatchInputReference batchX1,
254 ConstBatchInputReference batchX2,
255 RealMatrix
const& coefficients,
259 m_wrapper->weightedParameterDerivative(batchX1,batchX2,coefficients,state,gradient);
// Type-erased implementation. A scoped_ptr keeps it exclusively owned by this object, and the
// pointee type hides IntermediateType.
271 boost::scoped_ptr<AbstractKernelFunction<InputType> > m_wrapper;