Shark machine learning library
About Shark
News!
Contribute
Credits and copyright
Downloads
Getting Started
Installation
Using the docs
Documentation
Tutorials
Quick references
Class list
Global functions
FAQ
Showroom
include
shark
ObjectiveFunctions
ErrorFunction.h
Go to the documentation of this file.
1
/*!
2
*
3
*
4
* \brief error function for supervised learning
5
*
6
*
7
*
8
* \author T.Voss, T. Glasmachers, O.Krause
9
* \date 2010-2011
10
*
11
*
12
* \par Copyright 1995-2017 Shark Development Team
13
*
14
* <BR><HR>
15
* This file is part of Shark.
16
* <http://shark-ml.org/>
17
*
18
* Shark is free software: you can redistribute it and/or modify
19
* it under the terms of the GNU Lesser General Public License as published
20
* by the Free Software Foundation, either version 3 of the License, or
21
* (at your option) any later version.
22
*
23
* Shark is distributed in the hope that it will be useful,
24
* but WITHOUT ANY WARRANTY; without even the implied warranty of
25
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
26
* GNU Lesser General Public License for more details.
27
*
28
* You should have received a copy of the GNU Lesser General Public License
29
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
30
*
31
*/
32
#ifndef SHARK_OBJECTIVEFUNCTIONS_ERRORFUNCTION_H
33
#define SHARK_OBJECTIVEFUNCTIONS_ERRORFUNCTION_H
34
35
36
#include <
shark/Models/AbstractModel.h
>
37
#include <
shark/ObjectiveFunctions/Loss/AbstractLoss.h
>
38
#include <
shark/ObjectiveFunctions/AbstractObjectiveFunction.h
>
39
#include <
shark/Data/Dataset.h
>
40
#include <
shark/Data/WeightedDataset.h
>
41
#include "Impl/FunctionWrapperBase.h"
42
43
#include <boost/scoped_ptr.hpp>
44
45
namespace
shark
{
46
47
///
48
/// \brief Objective function for supervised learning
49
///
50
/// \par
51
/// An ErrorFunction object is an objective function for
52
/// learning the parameters of a model from data by means
53
/// of minimization of a cost function. The value of the
54
/// objective function is the cost of the model predictions
55
/// on the training data, given the targets.
56
/// \par
57
/// It supports mini-batch learning using an optional fourth argument to
58
/// The constructor. With mini-batch learning enabled, each iteration a random
59
/// batch is taken from the dataset. Thus the size of the minibatch is the size of the batches in
60
/// the datasets. Normalization ensures that batches of different sizes have approximately the same
61
/// magnitude of error and derivative.
62
///
63
///\par
64
/// It automatically infers the input und label type from the given dataset and the output type
65
/// of the model in the constructor and ensures that Model and loss match. Thus the user does
66
/// not need to provide the types as template parameters.
67
class
ErrorFunction
:
public
SingleObjectiveFunction
68
{
69
public
:
70
template
<
class
InputType,
class
LabelType,
class
OutputType>
71
ErrorFunction
(
72
LabeledData<InputType, LabelType>
const
& dataset,
73
AbstractModel<InputType,OutputType>
* model,
74
AbstractLoss<LabelType, OutputType>
* loss,
75
bool
useMiniBatches =
false
76
);
77
template
<
class
InputType,
class
LabelType,
class
OutputType>
78
ErrorFunction
(
79
WeightedLabeledData<InputType, LabelType>
const
& dataset,
80
AbstractModel<InputType,OutputType>
* model,
81
AbstractLoss<LabelType, OutputType>
* loss
82
);
83
ErrorFunction
(
const
ErrorFunction
& op);
84
ErrorFunction
&
operator=
(
const
ErrorFunction
& op);
85
86
std::string
name
()
const
87
{
return
"ErrorFunction"
; }
88
89
void
setRegularizer
(
double
factor,
SingleObjectiveFunction
* regularizer){
90
m_regularizer = regularizer;
91
m_regularizationStrength = factor;
92
}
93
94
SearchPointType
proposeStartingPoint
()
const
{
95
return
mp_wrapper ->
proposeStartingPoint
();
96
}
97
std::size_t
numberOfVariables
()
const
{
98
return
mp_wrapper ->
numberOfVariables
();
99
}
100
101
void
init
(){
102
mp_wrapper->setRng(this->
mep_rng
);
103
mp_wrapper->
init
();
104
}
105
106
double
eval
(RealVector
const
& input)
const
;
107
ResultType
evalDerivative
(
const
SearchPointType
& input,
FirstOrderDerivative
& derivative )
const
;
108
109
friend
void
swap
(
ErrorFunction
& op1,
ErrorFunction
& op2);
110
111
private
:
112
boost::scoped_ptr<detail::FunctionWrapperBase > mp_wrapper;
113
SingleObjectiveFunction
* m_regularizer;
114
double
m_regularizationStrength;
115
};
116
117
}
118
#include "Impl/ErrorFunction.inl"
119
#endif