Shark machine learning library
About Shark
News!
Contribute
Credits and copyright
Downloads
Getting Started
Installation
Using the docs
Documentation
Tutorials
Quick references
Class list
Global functions
FAQ
Showroom
include
shark
Algorithms
Trainers
OptimizationTrainer.h
Go to the documentation of this file.
1
//===========================================================================
2
/*!
3
*
4
*
5
* \brief Model training by means of a general purpose optimization procedure.
6
*
7
*
8
*
9
* \author T. Glasmachers
10
* \date 2011-2012
11
*
12
*
13
* \par Copyright 1995-2017 Shark Development Team
14
*
15
* <BR><HR>
16
* This file is part of Shark.
17
* <http://shark-ml.org/>
18
*
19
* Shark is free software: you can redistribute it and/or modify
20
* it under the terms of the GNU Lesser General Public License as published
21
* by the Free Software Foundation, either version 3 of the License, or
22
* (at your option) any later version.
23
*
24
* Shark is distributed in the hope that it will be useful,
25
* but WITHOUT ANY WARRANTY; without even the implied warranty of
26
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27
* GNU Lesser General Public License for more details.
28
*
29
* You should have received a copy of the GNU Lesser General Public License
30
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
31
*
32
*/
33
//===========================================================================
34
35
#ifndef SHARK_ALGORITHMS_TRAINERS_OPTIMIZATIONTRAINER_H
36
#define SHARK_ALGORITHMS_TRAINERS_OPTIMIZATIONTRAINER_H
37
38
#include <
shark/Algorithms/AbstractSingleObjectiveOptimizer.h
>
39
#include <
shark/Core/ResultSets.h
>
40
#include <
shark/Models/AbstractModel.h
>
41
#include <
shark/ObjectiveFunctions/ErrorFunction.h
>
42
#include <
shark/Algorithms/Trainers/AbstractTrainer.h
>
43
#include <
shark/Algorithms/StoppingCriteria/AbstractStoppingCriterion.h
>
44
45
namespace
shark
{
46
47
48
///
49
/// \brief Wrapper for training schemes based on (iterative) optimization.
50
///
51
/// \par
52
/// The OptimizationTrainer class is designed to allow for
53
/// model training via iterative minimization of a
54
/// loss function, such as in neural network
55
/// "backpropagation" training.
56
///
57
template
<
class
Model,
class
LabelTypeT =
typename
Model::OutputType>
58
class
OptimizationTrainer
:
public
AbstractTrainer
<Model,LabelTypeT>
59
{
60
typedef
AbstractTrainer<Model,LabelTypeT>
base_type
;
61
62
public
:
63
typedef
typename
base_type::InputType
InputType
;
64
typedef
typename
base_type::LabelType
LabelType
;
65
typedef
typename
base_type::ModelType
ModelType
;
66
67
typedef
AbstractSingleObjectiveOptimizer< RealVector >
OptimizerType
;
68
typedef
AbstractLoss< LabelType, InputType >
LossType
;
69
typedef
AbstractStoppingCriterion<SingleObjectiveResultSet<OptimizerType::SearchPointType>
>
StoppingCriterionType
;
70
71
OptimizationTrainer
(
72
LossType* loss,
73
OptimizerType* optimizer,
74
StoppingCriterionType* stoppingCriterion)
75
:
mep_loss
(loss),
mep_optimizer
(optimizer),
mep_stoppingCriterion
(stoppingCriterion)
76
{
77
SHARK_RUNTIME_CHECK
(loss !=
nullptr
,
"Loss function must not be NULL"
);
78
SHARK_RUNTIME_CHECK
(optimizer !=
nullptr
,
"optimizer must not be NULL"
);
79
SHARK_RUNTIME_CHECK
(stoppingCriterion !=
nullptr
,
"Stopping Criterion must not be NULL"
);
80
}
81
82
/// \brief From INameable: return the class name.
83
std::string
name
()
const
84
{
85
return
"OptimizationTrainer<"
86
+
mep_loss
->
name
() +
","
87
+
mep_optimizer
->
name
() +
">"
;
88
}
89
90
void
train
(ModelType& model,
LabeledData<InputType, LabelType>
const
& dataset) {
91
ErrorFunction
error(dataset, &model,
mep_loss
);
92
error.
init
();
93
mep_optimizer
->
init
(error);
94
mep_stoppingCriterion
->
reset
();
95
do
{
96
mep_optimizer
->
step
(error);
97
}
98
while
(!
mep_stoppingCriterion
->
stop
(
mep_optimizer
->
solution
()));
99
model.setParameterVector(
mep_optimizer
->
solution
().
point
);
100
}
101
102
void
read
(
InArchive
& archive )
103
{}
104
105
void
write
(
OutArchive
& archive )
const
106
{}
107
108
protected
:
109
LossType*
mep_loss
;
110
OptimizerType*
mep_optimizer
;
111
StoppingCriterionType*
mep_stoppingCriterion
;
112
};
113
114
115
}
116
#endif