CrossValidationError.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  * \brief cross-validation error for selection of hyper-parameters
4  *
5  *
6  * \author T. Glasmachers, O. Krause
7  * \date 2007-2012
8  *
9  *
10  * <BR><HR>
11  * This file is part of Shark. This library is free software;
12  * you can redistribute it and/or modify it under the terms of the
13  * GNU General Public License as published by the Free Software
14  * Foundation; either version 3, or (at your option) any later version.
15  *
16  * This library is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19  * GNU General Public License for more details.
20  *
21  * You should have received a copy of the GNU General Public License
22  * along with this library; if not, see <http://www.gnu.org/licenses/>.
23  *
24  */
25 //===========================================================================
26 
27 #ifndef SHARK_OBJECTIVEFUNCTIONS_CROSSVALIDATIONERROR_H
28 #define SHARK_OBJECTIVEFUNCTIONS_CROSSVALIDATIONERROR_H
29 
35 
36 namespace shark {
37 
38 
39 ///
40 /// \brief Cross-validation error for selection of hyper-parameters.
41 ///
42 /// \par
43 /// The cross-validation error is useful for evaluating
44 /// how well a model performs on a problem. It is regularly
45 /// used for model selection.
46 ///
47 /// \par
48 /// In Shark, the cross-validation procedure is abstracted
49 /// as follows:
50 /// First, the given point is written into an IParameterizable
51 /// object (such as a regularizer or a trainer). Then a model
52 /// is trained with a trainer with the given settings on a
53 /// number of folds and evaluated on the corresponding validation
54 /// sets with a cost function. The average cost function value
55 /// over all folds is returned.
56 ///
57 /// \par
58 /// Thus, the cross-validation procedure requires a "meta"
59 /// IParameterizable object, a model, a trainer, a data set,
60 /// and a cost function.
61 ///
62 template<class ModelTypeT, class LabelTypeT = typename ModelTypeT::OutputType>
63 class CrossValidationError : public AbstractObjectiveFunction< VectorSpace<double>, double >
64 {
    // Objective function whose value at a parameter vector is the mean of a
    // cost function over all folds, after training the model on each fold's
    // training part (see eval()).
    //
    // Template parameters, as used by the visible code:
    //   ModelTypeT - model type; must provide nested InputType and OutputType
    //                types and be callable as (*model)(inputs).
    //   LabelTypeT - label type; defaults to ModelTypeT::OutputType.
    //
    // NOTE(review): this listing carries embedded source-line numbers and
    // skips some of them (68->71, 71->74, 74->76, 85->87). FoldsType,
    // DatasetType, TrainerType and CostType are used below but are not
    // declared on any visible line - presumably they were declared on the
    // skipped lines; confirm against the original header.
65 public:
66  typedef typename ModelTypeT::InputType InputType;
67  typedef typename ModelTypeT::OutputType OutputType;
68  typedef LabelTypeT LabelType;
71  typedef ModelTypeT ModelType;
74 private:
76 
77 
    // Folds object, initialized from the constructor's dataFolds argument.
78  FoldsType m_folds;
    // Object that receives the parameter vector in eval() and reports the
    // number of variables in numberOfVariables().
    // NOTE(review): this and the three pointers below are stored exactly as
    // passed in and are never deleted in the visible code - presumably
    // non-owning, so the pointees must outlive this object; confirm.
79  IParameterizable* mep_meta;
    // Model that is trained and then applied to the validation inputs.
80  ModelType* mep_model;
    // Trainer invoked once per fold on *mep_model.
81  TrainerType* mep_trainer;
    // Cost function comparing validation labels with model outputs.
82  CostType* mep_cost;
83 
84 public:
85 
    // Constructor: stores the folds and the four pointers; no null checks
    // are performed here.
    //   dataFolds - folds used for training/validation in eval()
    //   meta      - target of setParameterVector() / numberOfParameters()
    //   model     - model passed to the trainer and evaluated on each fold
    //   trainer   - trainer used on every fold
    //   cost      - cost function whose per-fold values are averaged
    // NOTE(review): the line that should carry the constructor's name
    // (embedded number 86) is absent from this listing - verify against the
    // original header.
87  FoldsType const& dataFolds,
88  IParameterizable* meta,
89  ModelType* model,
90  TrainerType* trainer,
91  CostType* cost)
92  : m_folds(dataFolds)
93  , mep_meta(meta)
94  , mep_model(model)
95  , mep_trainer(trainer)
96  , mep_cost(cost)
97  { }
98 
99  /// \brief From INameable: return the class name.
    ///
    /// The returned string has the form
    /// "CrossValidationError<MODEL,TRAINER,COST>", where the three parts are
    /// the name() results of the stored model, trainer and cost objects.
    /// The three pointers are dereferenced without a null check.
100  std::string name() const
101  {
102  return "CrossValidationError<"
103  + mep_model->name() + ","
104  + mep_trainer->name() + ","
105  + mep_cost->name() + ">";
106  }
107 
108  /// configure the cross validation
    ///
    /// Currently a no-op: the body is empty and \a node is not read.
109  void configure( const PropertyTree & node ) {}
110 
    /// \brief Number of variables of this objective function.
    ///
    /// Forwards to numberOfParameters() of the stored meta object.
111  std::size_t numberOfVariables()const{
112  return mep_meta->numberOfParameters();
113  }
114 
115  /// Evaluate the cross-validation error:
116  /// train sub-models, evaluate objective,
117  /// return the average.
    ///
    /// \param parameters  vector written into the meta object via
    ///                    setParameterVector() before any training happens.
    /// \return sum of the per-fold cost values divided by m_folds.size().
    ///
    /// Although declared const, this call increments m_evaluationCounter and
    /// modifies the objects behind mep_meta and mep_model.
    /// NOTE(review): m_evaluationCounter is not declared in the visible
    /// code - presumably a mutable member of the base class; confirm.
118  double eval(RealVector const& parameters) const {
119  this->m_evaluationCounter++;
120  mep_meta->setParameterVector(parameters);
121 
    // Accumulates the cost over all folds.
122  double ret = 0.0;
123  for (size_t setID=0; setID != m_folds.size(); ++setID) {
124  DatasetType train = m_folds.training(setID);
125  DatasetType validation = m_folds.validation(setID);
    // The same model object is handed to the trainer on every fold.
126  mep_trainer->train(*mep_model, train);
    // Apply the freshly trained model to this fold's validation inputs.
127  Data<OutputType> output = (*mep_model)(validation.inputs());
128  ret += mep_cost->eval(validation.labels(), output);
129  }
    // NOTE(review): with zero folds the loop is skipped and this divides 0.0
    // by zero (NaN under IEEE arithmetic) - confirm whether an empty fold
    // set can reach this point.
130  return ret / m_folds.size();
131  }
132 };
133 
134 
135 }
136 #endif