32 #ifndef SHARK_ML_SVMLOGISTICINTERPRETATION_H 33 #define SHARK_ML_SVMLOGISTICINTERPRETATION_H 62 template<
class InputType = RealVector>
// Constructor fragment: only part of the parameter list, the member
// initializers and the opening sanity checks are visible in this excerpt.
85 FoldsType
const &folds, KernelType *kernel,
// Total hyperparameter count: the kernel's parameters plus one extra slot.
// Elsewhere in this file that extra slot (index m_nkp) is read as the
// regularization value C_reg.
89 , m_nhp(kernel->parameterVector().size()+1)
// Number of kernel parameters alone.
90 , m_nkp(kernel->parameterVector().size())
// Fold count and total element count are both taken from the folds object.
91 , m_numFolds(folds.size())
92 , m_numSamples(folds.dataset().numberOfElements())
// Flag selecting the encoding of C; when set, C is read through std::exp().
94 , m_svmCIsUnconstrained(unconstrained)
// Stopping-condition pointer; later code treats NULL as "none supplied".
95 , mep_svmStoppingCondition(stop_cond)
// NOTE(review): the initializers above already call kernel->parameterVector(),
// so a NULL kernel is dereferenced before this check can fire. Confirm, and
// consider validating the pointer before the member initializers use it.
97 SHARK_RUNTIME_CHECK(kernel != NULL,
"[SvmLogisticInterpretation::SvmLogisticInterpretation] kernel is not allowed to be NULL");
// Cross-validation needs at least two folds.
98 SHARK_RUNTIME_CHECK(m_numFolds > 1,
"[SvmLogisticInterpretation::SvmLogisticInterpretation] please provide a meaningful number of folds for cross validation");
// Branch taken only in the constrained-C case; its body is not visible here.
99 if (!m_svmCIsUnconstrained)
// Body of an accessor whose signature lies outside this excerpt (presumably
// the objective function's name getter, TODO confirm). It returns a fixed
// identifier string.
109 {
return "SvmLogisticInterpretation"; }
// Fragment of eval(): the signature and the per-fold training code are not
// visible in this excerpt.
// The parameter vector must contain exactly m_nhp entries.
132 SHARK_RUNTIME_CHECK(m_nhp == parameters.size(),
"[SvmLogisticInterpretation::eval] wrong number of parameters");
// The regularization value is the entry at index m_nkp. In unconstrained mode
// that entry is treated as log(C), so exp() recovers C.
134 double C_reg = (m_svmCIsUnconstrained ? std::exp(parameters(m_nkp)) : parameters(m_nkp));
// The stopping condition is optional; this branch runs only when one was set
// (branch body not visible here).
144 if (mep_svmStoppingCondition != NULL) {
// Fit a logistic model (a LinearModel) on the validation dataset.
156 LinearModel<> logistic_model = fitLogistic(validation_dataset);
// Result: logistic_loss of the fitted model's outputs on the validation inputs,
// measured against the validation labels.
160 return logistic_loss(validation_dataset.
labels(),logistic_model(validation_dataset.
inputs()));
// Fragment of evalDerivative(): the signature, the fold loop header and several
// intermediate statements are not visible in this excerpt.
// The parameter vector must contain exactly m_nhp entries.
168 SHARK_RUNTIME_CHECK(m_nhp == parameters.size(),
"[SvmLogisticInterpretation::evalDerivative] wrong number of parameters");
// The regularization value is the entry at index m_nkp. In unconstrained mode
// that entry is treated as log(C), so exp() recovers C.
170 double C_reg = (m_svmCIsUnconstrained ? std::exp(parameters(m_nkp)) : parameters(m_nkp));
// Per-sample buffers (one slot per sample) for validation labels and SVM scores.
174 std::vector< unsigned int > tmp_helper_labels(m_numSamples);
175 std::vector< RealVector > tmp_helper_preds(m_numSamples);
// Running write index into the buffers above and into the derivative matrix.
// NOTE(review): its increment is not visible in this excerpt.
177 unsigned int next_label = 0;
// One row per sample, one column per hyperparameter; rows are filled below.
179 RealMatrix all_validation_predict_derivs(m_numSamples, m_nhp);
// The stopping condition is optional; this branch runs only when one was set
// (branch body not visible here).
196 if (mep_svmStoppingCondition != NULL) {
// Train the SVM on the current training data.
200 csvm_trainer.
train(svm, cur_train_data);
// Copy each validation element's label and score into the next free slot.
204 for (std::size_t j=0; j<cur_vsize; j++) {
206 tmp_helper_labels[next_label] = cur_vlabels.
element(j);
207 tmp_helper_preds[next_label] = cur_vscores.
element(j);
// Store `der` (computed in lines not visible here) as this sample's row.
210 noalias(row(all_validation_predict_derivs, next_label)) = der;
// Fit a logistic model (a LinearModel) on the assembled validation dataset.
218 LinearModel<> logistic_model = fitLogistic(validation_dataset);
// The output gradient has one entry per hyperparameter.
222 derivative.resize(m_nhp);
// Row offset of the current batch within all_validation_predict_derivs.
// NOTE(review): the update of `start` after each batch is not visible here.
225 std::size_t start = 0;
// Walk the validation batches, accumulating the loss and its gradient.
226 for(
auto const& batch: validation_dataset.
batches()){
227 std::size_t end = start+batch.size();
// Receives the loss derivative for this batch from evalDerivative below.
229 RealMatrix lossGradient;
230 error += logistic_loss.
evalDerivative(batch.label,logistic_model(batch.input),lossGradient);
// Combine column 0 of the loss gradient with this batch's rows [start, end)
// of the per-sample derivative matrix. `%` is presumably the linear-algebra
// product operator of the matrix library in use (TODO confirm).
231 noalias(derivative) += column(lossGradient,0) % rows(all_validation_predict_derivs,start,end);
// Fragment, presumably the tail of fitLogistic() (TODO confirm): callers earlier
// in this file store its result as `logistic_model`. The setup of `optimizer`,
// `error` and `logistic_model` is not visible in this excerpt.
// Initialise the optimizer on `error`.
245 optimizer.
init(error);
// Perform an optimisation step; presumably inside a loop that is not visible
// here (TODO confirm).
249 optimizer.
step(error);
// Hand the model back to the caller.
253 return logistic_model;