Adam.h
//===========================================================================
/*!
 *
 * \brief Adam
 *
 * \author O. Krause
 * \date 2017
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <http://shark-ml.org/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================
#ifndef SHARK_ML_OPTIMIZER_ADAM_H
#define SHARK_ML_OPTIMIZER_ADAM_H

#include <shark/Algorithms/AbstractSingleObjectiveOptimizer.h>

namespace shark{

///@brief Adaptive Moment Estimation Algorithm (ADAM)
///
/// Performs stochastic gradient descent (SGD) using a long-term average of the gradient
/// as well as of its second moment to adapt a step size for each coordinate.
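///
/// A minimal usage sketch (here `loss` is a placeholder for any differentiable
/// single-objective function, e.g. an ErrorFunction, and `start` is an initial
/// parameter vector of matching dimension):
/// \code
/// Adam optimizer;
/// optimizer.setEta(0.001);                  // optional: tune the learning rate
/// optimizer.init(loss, start);              // evaluates the function once at `start`
/// for(std::size_t t = 0; t != 1000; ++t){   // run a fixed number of steps
///     optimizer.step(loss);
/// }
/// RealVector solutionPoint = optimizer.solution().point;
/// double solutionValue = optimizer.solution().value;
/// \endcode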
class Adam : public AbstractSingleObjectiveOptimizer<RealVector >
{
public:
    Adam() {
        m_features |= REQUIRES_FIRST_DERIVATIVE;

        //default hyperparameters as recommended in the original Adam paper (Kingma & Ba)
        m_beta1 = 0.9;
        m_beta2 = 0.999;
        m_epsilon = 1.e-8;
        m_eta = 0.001;
    }

    /// \brief From INameable: return the class name.
    std::string name() const
    { return "Adam"; }

    void init(ObjectiveFunctionType const& objectiveFunction, SearchPointType const& startingPoint) {
        checkFeatures(objectiveFunction);
        SHARK_RUNTIME_CHECK(startingPoint.size() == objectiveFunction.numberOfVariables(), "Initial starting point and dimensionality of function do not agree");

        //initialize long-term averages
        m_avgGrad = blas::repeat(0.0,startingPoint.size());
        m_secondMoment = blas::repeat(0.0,startingPoint.size());
        m_counter = 0;

        //set point to the current starting point
        m_best.point = startingPoint;
        m_best.value = objectiveFunction.evalDerivative(m_best.point,m_derivative);
    }
    using AbstractSingleObjectiveOptimizer<RealVector >::init;

    /// \brief get learning rate eta
    double eta() const {
        return m_eta;
    }

    /// \brief set learning rate eta
    void setEta(double eta) {
        SHARK_RUNTIME_CHECK(eta > 0, "eta must be positive.");
        m_eta = eta;
    }

    /// \brief get gradient averaging parameter beta1
    double beta1() const {
        return m_beta1;
    }

    /// \brief set gradient averaging parameter beta1
    void setBeta1(double beta1) {
        SHARK_RUNTIME_CHECK(beta1 > 0, "beta1 must be positive.");
        m_beta1 = beta1;
    }

    /// \brief get gradient averaging parameter beta2
    double beta2() const {
        return m_beta2;
    }

    /// \brief set gradient averaging parameter beta2
    void setBeta2(double beta2) {
        SHARK_RUNTIME_CHECK(beta2 > 0, "beta2 must be positive.");
        m_beta2 = beta2;
    }

    /// \brief get minimum noise estimate epsilon
    double epsilon() const {
        return m_epsilon;
    }

    /// \brief set minimum noise estimate epsilon
    void setEpsilon(double epsilon) {
        SHARK_RUNTIME_CHECK(epsilon > 0, "epsilon must be positive.");
        m_epsilon = epsilon;
    }

    /// \brief Performs a step of the optimization.
    ///
    /// First, the running estimates of the gradient and of its (element-wise) second moment are updated using
    /// \f[ g_t = \beta_1 g_{t-1} + (1-\beta_1) \frac{\partial}{\partial x} f(x_{t-1})\f]
    /// \f[ v_t = \beta_2 v_{t-1} + (1-\beta_2) \left(\frac{\partial}{\partial x} f(x_{t-1})\right)^2\f]
    ///
    /// The step is then performed as
    /// \f[ x_{t} = x_{t-1} - \eta \, g_t \left(\sqrt{v_t} + \epsilon\right)^{-1} \f]
    /// where a bias correction is applied to remove the bias in the first few iterations, where the running means are still close to 0.
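    ///
    /// Concretely, with the bias corrections \f$ 1-\beta_1^t \f$ and \f$ 1-\beta_2^t \f$ applied by the implementation, the update computed in step() reads
    /// \f[ x_{t} = x_{t-1} - \frac{\eta}{1-\beta_1^t} \, g_t \left(\epsilon + \sqrt{v_t/(1-\beta_2^t)}\right)^{-1} \f]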
    void step(ObjectiveFunctionType const& objectiveFunction) {
        //update long-term averages of the gradient and its variance
        noalias(m_avgGrad) = m_beta1 * m_avgGrad + (1-m_beta1) * m_derivative;
        noalias(m_secondMoment) = m_beta2 * m_secondMoment + (1-m_beta2)* sqr(m_derivative);
        //for the first few iterations, we need bias correction
        ++m_counter;
        double bias1 = 1-std::pow(m_beta1,m_counter);
        double bias2 = 1-std::pow(m_beta2,m_counter);
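        //apply the bias-corrected update: x <- x - (eta/bias1) * g_t / (epsilon + sqrt(v_t/bias2))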
        noalias(m_best.point) -= (m_eta/bias1) * m_avgGrad/(m_epsilon + sqrt(m_secondMoment/bias2));
        m_best.value = objectiveFunction.evalDerivative(m_best.point,m_derivative);
    }
    virtual void read( InArchive & archive )
    {
        archive>>m_avgGrad;
        archive>>m_secondMoment;
        archive>>m_counter;
        archive>>m_derivative;
        archive>>m_best;

        archive>>m_beta1;
        archive>>m_beta2;
        archive>>m_epsilon;
        archive>>m_eta;
    }

    virtual void write( OutArchive & archive ) const
    {
        archive<<m_avgGrad;
        archive<<m_secondMoment;
        archive<<m_counter;
        archive<<m_derivative;
        archive<<m_best;

        archive<<m_beta1;
        archive<<m_beta2;
        archive<<m_epsilon;
        archive<<m_eta;
    }

private:
    RealVector m_avgGrad;
    RealVector m_secondMoment;
    unsigned int m_counter;
    SearchPointType m_derivative;

    double m_beta1;
    double m_beta2;
    double m_epsilon;
    double m_eta;
};

}
#endif