Shark machine learning library
About Shark
News!
Contribute
Credits and copyright
Downloads
Getting Started
Installation
Using the docs
Documentation
Tutorials
Quick references
Class list
Global functions
FAQ
Showroom
include
shark
Algorithms
GradientDescent
SteepestDescent.h
Go to the documentation of this file.
//===========================================================================
/*!
 *
 *
 * \brief SteepestDescent: gradient descent optimizer with an optional momentum term.
 *
 *
 *
 * \author O. Krause
 * \date 2010
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <http://shark-ml.org/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================
34
#ifndef SHARK_ML_OPTIMIZER_STEEPESTDESCENT_H
35
#define SHARK_ML_OPTIMIZER_STEEPESTDESCENT_H
36
37
#include <
shark/Algorithms/AbstractSingleObjectiveOptimizer.h
>
38
39
namespace
shark
{
40
41
///@brief Standard steepest descent.
42
class
SteepestDescent
:
public
AbstractSingleObjectiveOptimizer
<RealVector >
43
{
44
public
:
45
SteepestDescent
() {
46
m_features
|=
REQUIRES_FIRST_DERIVATIVE
;
47
48
m_learningRate = 0.1;
49
m_momentum = 0.0;
50
}
51
52
/// \brief From INameable: return the class name.
53
std::string
name
()
const
54
{
return
"SteepestDescent"
; }
55
56
void
init
(
ObjectiveFunctionType
const
& objectiveFunction,
SearchPointType
const
& startingPoint) {
57
checkFeatures
(objectiveFunction);
58
SHARK_RUNTIME_CHECK
(startingPoint.size() == objectiveFunction.
numberOfVariables
(),
"Initial starting point and dimensionality of function do not agree"
);
59
60
m_path.resize(startingPoint.size());
61
m_path.clear();
62
m_best
.
point
= startingPoint;
63
m_best
.
value
= objectiveFunction.
evalDerivative
(
m_best
.
point
,m_derivative);
64
}
65
using
AbstractSingleObjectiveOptimizer<RealVector >::init
;
66
67
/*!
68
* \brief get learning rate
69
*/
70
double
learningRate
()
const
{
71
return
m_learningRate;
72
}
73
74
/*!
75
* \brief set learning rate
76
*/
77
void
setLearningRate
(
double
learningRate
) {
78
m_learningRate =
learningRate
;
79
}
80
81
/*!
82
* \brief get momentum parameter
83
*/
84
double
momentum
()
const
{
85
return
m_momentum;
86
}
87
88
/*!
89
* \brief set momentum parameter
90
*/
91
void
setMomentum
(
double
momentum
) {
92
m_momentum =
momentum
;
93
}
94
/*!
95
* \brief updates searchdirection and then does simple gradient descent
96
*/
97
void
step
(
ObjectiveFunctionType
const
& objectiveFunction) {
98
m_path = -m_learningRate * m_derivative + m_momentum * m_path;
99
m_best
.
point
+=m_path;
100
m_best
.
value
= objectiveFunction.
evalDerivative
(
m_best
.
point
,m_derivative);
101
}
102
virtual
void
read
(
InArchive
& archive )
103
{
104
archive>>m_path;
105
archive>>m_learningRate;
106
archive>>m_momentum;
107
}
108
109
virtual
void
write
(
OutArchive
& archive )
const
110
{
111
archive<<m_path;
112
archive<<m_learningRate;
113
archive<<m_momentum;
114
}
115
116
private
:
117
RealVector m_path;
118
ObjectiveFunctionType::FirstOrderDerivative
m_derivative;
119
double
m_learningRate;
120
double
m_momentum;
121
};
122
123
}
124
#endif
125