36 #ifndef SHARK_OBJECTIVEFUNCTIONS_CROSSVALIDATIONERROR_H 37 #define SHARK_OBJECTIVEFUNCTIONS_CROSSVALIDATIONERROR_H 71 template<
class ModelTypeT,
class LabelTypeT =
typename ModelTypeT::OutputType>
90 TrainerType* mep_trainer;
96 FoldsType
const& dataFolds,
104 , mep_trainer(trainer)
111 return "CrossValidationError<" 112 + mep_model->name() +
"," 113 + mep_trainer->
name() +
"," 114 + mep_cost->
name() +
">";
124 double eval(RealVector
const& parameters)
const {
129 for (
size_t setID=0; setID != m_folds.
size(); ++setID) {
130 DatasetType train = m_folds.
training(setID);
131 DatasetType validation = m_folds.
validation(setID);
132 mep_trainer->
train(*mep_model, train);
134 ret += mep_cost->
eval(validation.
labels(), output);
136 return ret / m_folds.
size();