CrossValidationError.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief cross-validation error for selection of hyper-parameters
6 
7 
8  *
9  *
10  * \author T. Glasmachers, O. Krause
11  * \date 2007-2012
12  *
13  *
14  * \par Copyright 1995-2017 Shark Development Team
15  *
16  * <BR><HR>
17  * This file is part of Shark.
18  * <http://shark-ml.org/>
19  *
20  * Shark is free software: you can redistribute it and/or modify
21  * it under the terms of the GNU Lesser General Public License as published
22  * by the Free Software Foundation, either version 3 of the License, or
23  * (at your option) any later version.
24  *
25  * Shark is distributed in the hope that it will be useful,
26  * but WITHOUT ANY WARRANTY; without even the implied warranty of
27  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
28  * GNU Lesser General Public License for more details.
29  *
30  * You should have received a copy of the GNU Lesser General Public License
31  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
32  *
33  */
34 //===========================================================================
35 
36 #ifndef SHARK_OBJECTIVEFUNCTIONS_CROSSVALIDATIONERROR_H
37 #define SHARK_OBJECTIVEFUNCTIONS_CROSSVALIDATIONERROR_H
38 
44 
45 namespace shark {
46 
47 
48 ///
49 /// \brief Cross-validation error for selection of hyper-parameters.
50 ///
51 /// \par
52 /// The cross-validation error is useful for evaluating
53 /// how well a model performs on a problem. It is regularly
54 /// used for model selection.
55 ///
56 /// \par
57 /// In Shark, the cross-validation procedure is abstracted
58 /// as follows:
59 /// First, the given point is written into an IParameterizable
60 /// object (such as a regularizer or a trainer). Then a model
61 /// is trained with a trainer with the given settings on a
62 /// number of folds and evaluated on the corresponding validation
63 /// sets with a cost function. The average cost function value
64 /// over all folds is returned.
65 ///
66 /// \par
67 /// Thus, the cross-validation procedure requires a "meta"
68 /// IParameterizable object, a model, a trainer, a data set,
69 /// and a cost function.
70 ///
71 template<class ModelTypeT, class LabelTypeT = typename ModelTypeT::OutputType>
// NOTE(review): the class declaration itself (original listing line 72,
// presumably "class CrossValidationError : public AbstractObjectiveFunction<...>")
// is missing from this extract — confirm against the full header.
73 {
74 public:
75  typedef typename ModelTypeT::InputType InputType;    // input type of the wrapped model
76  typedef typename ModelTypeT::OutputType OutputType;  // output type of the wrapped model
77  typedef LabelTypeT LabelType;                        // label type; defaults to the model's output type
80  typedef ModelTypeT ModelType;
// NOTE(review): the typedefs for FoldsType, DatasetType, TrainerType and
// CostType (original listing lines 78-85) are missing from this extract,
// although the members below and eval() rely on them — TODO confirm.
83 private:
85 
86 
// Non-owning pointers: the caller retains ownership of the meta object,
// model, trainer and cost function, and must keep them alive for the
// lifetime of this objective function (presumed from the "mep_" prefix
// convention — confirm against Shark's naming guidelines).
87  FoldsType m_folds;
88  IParameterizable<>* mep_meta;
89  ModelType* mep_model;
90  TrainerType* mep_trainer;
91  CostType* mep_cost;
92 
93 public:
94 
/// \brief Construct the cross-validation error from its components.
///
/// NOTE(review): the constructor's declaration line (original listing
/// line 95, presumably "CrossValidationError(") is missing from this
/// extract — confirm against the full header.
///
/// \param dataFolds  partition of the data into training/validation folds (copied)
/// \param meta       object receiving the parameter vector in eval(); not owned
/// \param model      model trained on each fold; not owned
/// \param trainer    trainer used to fit the model per fold; not owned
/// \param cost       cost function evaluated on each validation set; not owned
96  FoldsType const& dataFolds,
97  IParameterizable<>* meta,
98  ModelType* model,
99  TrainerType* trainer,
100  CostType* cost)
101  : m_folds(dataFolds)
102  , mep_meta(meta)
103  , mep_model(model)
104  , mep_trainer(trainer)
105  , mep_cost(cost)
106  { }
107 
108  /// \brief From INameable: return the class name.
109  std::string name() const
110  {
111  return "CrossValidationError<"
112  + mep_model->name() + ","
113  + mep_trainer->name() + ","
114  + mep_cost->name() + ">";
115  }
116 
117  std::size_t numberOfVariables()const{
118  return mep_meta->numberOfParameters();
119  }
120 
121  /// Evaluate the cross-validation error:
122  /// train sub-models, evaluate objective,
123  /// return the average.
124  double eval(RealVector const& parameters) const {
125  this->m_evaluationCounter++;
126  mep_meta->setParameterVector(parameters);
127 
128  double ret = 0.0;
129  for (size_t setID=0; setID != m_folds.size(); ++setID) {
130  DatasetType train = m_folds.training(setID);
131  DatasetType validation = m_folds.validation(setID);
132  mep_trainer->train(*mep_model, train);
133  Data<OutputType> output = (*mep_model)(validation.inputs());
134  ret += mep_cost->eval(validation.labels(), output);
135  }
136  return ret / m_folds.size();
137  }
138 };
139 
140 
141 }
142 #endif