Shark machine learning library
About Shark
News!
Contribute
Credits and copyright
Downloads
Getting Started
Installation
Using the docs
Documentation
Tutorials
Class list
Global functions
FAQ
Showroom
Main Page
Related Pages
Modules
Classes
include
shark
ObjectiveFunctions
CrossValidationError.h
Go to the documentation of this file.
1
//===========================================================================
2
/*!
3
* \brief cross-validation error for selection of hyper-parameters
4
*
5
*
6
* \author T. Glasmachers, O. Krause
7
* \date 2007-2012
8
*
9
*
10
* <BR><HR>
11
* This file is part of Shark. This library is free software;
12
* you can redistribute it and/or modify it under the terms of the
13
* GNU General Public License as published by the Free Software
14
* Foundation; either version 3, or (at your option) any later version.
15
*
16
* This library is distributed in the hope that it will be useful,
17
* but WITHOUT ANY WARRANTY; without even the implied warranty of
18
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19
* GNU General Public License for more details.
20
*
21
* You should have received a copy of the GNU General Public License
22
* along with this library; if not, see <http://www.gnu.org/licenses/>.
23
*
24
*/
25
//===========================================================================
26
27
#ifndef SHARK_OBJECTIVEFUNCTIONS_CROSSVALIDATIONERROR_H
28
#define SHARK_OBJECTIVEFUNCTIONS_CROSSVALIDATIONERROR_H
29
30
#include <
shark/ObjectiveFunctions/DataObjectiveFunction.h
>
31
#include <
shark/Algorithms/Trainers/AbstractTrainer.h
>
32
#include <
shark/Algorithms/AbstractSingleObjectiveOptimizer.h
>
33
#include <
shark/ObjectiveFunctions/AbstractCost.h
>
34
#include <
shark/Data/CVDatasetTools.h
>
35
36
namespace
shark {
37
38
39
///
40
/// \brief Cross-validation error for selection of hyper-parameters.
41
///
42
/// \par
43
/// The cross-validation error is useful for evaluating
44
/// how well a model performs on a problem. It is regularly
45
/// used for model selection.
46
///
47
/// \par
48
/// In Shark, the cross-validation procedure is abstracted
49
/// as follows:
50
/// First, the given point is written into an IParameterizable
51
/// object (such as a regularizer or a trainer). Then a model
52
/// is trained with a trainer with the given settings on a
53
/// number of folds and evaluated on the corresponding validation
54
/// sets with a cost function. The average cost function value
55
/// over all folds is returned.
56
///
57
/// \par
58
/// Thus, the cross-validation procedure requires a "meta"
59
/// IParameterizable object, a model, a trainer, a data set,
60
/// and a cost function.
61
///
62
template
<
class
ModelTypeT,
class
LabelTypeT =
typename
ModelTypeT::OutputType>
63
class
CrossValidationError
:
public
AbstractObjectiveFunction
< VectorSpace<double>, double >
64
{
65
public
:
66
typedef
typename
ModelTypeT::InputType
InputType
;
67
typedef
typename
ModelTypeT::OutputType
OutputType
;
68
typedef
LabelTypeT
LabelType
;
69
typedef
LabeledData<InputType, LabelType>
DatasetType
;
70
typedef
CVFolds<DatasetType>
FoldsType
;
71
typedef
ModelTypeT
ModelType
;
72
typedef
AbstractTrainer<ModelType, LabelType>
TrainerType
;
73
typedef
AbstractCost<LabelType, OutputType>
CostType
;
74
private
:
75
typedef
AbstractObjectiveFunction< VectorSpace<double>
,
double
>
base_type
;
76
77
78
FoldsType
m_folds;
79
IParameterizable
* mep_meta;
80
ModelType
* mep_model;
81
TrainerType
* mep_trainer;
82
CostType
* mep_cost;
83
84
public
:
85
86
CrossValidationError
(
87
FoldsType
const
& dataFolds,
88
IParameterizable
* meta,
89
ModelType
* model,
90
TrainerType
* trainer,
91
CostType
* cost)
92
: m_folds(dataFolds)
93
, mep_meta(meta)
94
, mep_model(model)
95
, mep_trainer(trainer)
96
, mep_cost(cost)
97
{ }
98
99
/// \brief From INameable: return the class name.
100
std::string
name
()
const
101
{
102
return
"CrossValidationError<"
103
+ mep_model->name() +
","
104
+ mep_trainer->
name
() +
","
105
+ mep_cost->
name
() +
">"
;
106
}
107
108
/// configure the cross validation
109
void
configure
(
const
PropertyTree
& node ) {}
110
111
std::size_t
numberOfVariables
()
const
{
112
return
mep_meta->
numberOfParameters
();
113
}
114
115
/// Evaluate the cross-validation error:
116
/// train sub-models, evaluate objective,
117
/// return the average.
118
double
eval
(RealVector
const
&
parameters
)
const
{
119
this->
m_evaluationCounter
++;
120
mep_meta->
setParameterVector
(parameters);
121
122
double
ret = 0.0;
123
for
(
size_t
setID=0; setID != m_folds.
size
(); ++setID) {
124
DatasetType
train = m_folds.
training
(setID);
125
DatasetType
validation = m_folds.
validation
(setID);
126
mep_trainer->
train
(*mep_model, train);
127
Data<OutputType>
output = (*mep_model)(validation.
inputs
());
128
ret += mep_cost->
eval
(validation.
labels
(), output);
129
}
130
return
ret / m_folds.
size
();
131
}
132
};
133
134
135
}
136
#endif