OptimizationTrainer.h
//===========================================================================
/*!
 *
 * \brief Model training by means of a general purpose optimization procedure.
 *
 * \author T. Glasmachers
 * \date 2011-2012
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <http://shark-ml.org/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================

#ifndef SHARK_ALGORITHMS_TRAINERS_OPTIMIZATIONTRAINER_H
#define SHARK_ALGORITHMS_TRAINERS_OPTIMIZATIONTRAINER_H

#include <shark/Algorithms/AbstractSingleObjectiveOptimizer.h>
#include <shark/Algorithms/StoppingCriteria/AbstractStoppingCriterion.h>
#include <shark/Algorithms/Trainers/AbstractTrainer.h>
#include <shark/Core/ResultSets.h>
#include <shark/ObjectiveFunctions/ErrorFunction.h>
#include <shark/ObjectiveFunctions/Loss/AbstractLoss.h>

namespace shark {


///
/// \brief Wrapper for training schemes based on (iterative) optimization.
///
/// \par
/// The OptimizationTrainer class is designed to allow for
/// model training via iterative minimization of a
/// loss function, such as in neural network
/// "backpropagation" training.
///
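/// \par
/// A minimal usage sketch. The concrete loss, optimizer, and stopping
/// criterion below are stand-ins for any implementations of LossType,
/// OptimizerType, and StoppingCriterionType; check the exact class names,
/// headers, and template arguments against your Shark version, and assume
/// inputDim, outputDim, and trainingData are defined elsewhere.
/// \code
/// SquaredLoss<> loss;                        // any AbstractLoss
/// CG optimizer;                              // any single-objective optimizer
/// MaxIterations<> stop(100);                 // any stopping criterion
/// LinearModel<> model(inputDim, outputDim);  // the model to be trained
/// OptimizationTrainer<LinearModel<> > trainer(&loss, &optimizer, &stop);
/// trainer.train(model, trainingData);
/// \endcode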
template <class Model, class LabelTypeT = typename Model::OutputType>
class OptimizationTrainer : public AbstractTrainer<Model,LabelTypeT>
{
    typedef AbstractTrainer<Model,LabelTypeT> base_type;

public:
    typedef typename base_type::InputType InputType;
    typedef typename base_type::LabelType LabelType;
    typedef typename base_type::ModelType ModelType;

    typedef AbstractLoss<LabelType, typename Model::OutputType> LossType;
    typedef AbstractSingleObjectiveOptimizer<RealVector> OptimizerType;
    typedef AbstractStoppingCriterion<SingleObjectiveResultSet<RealVector> > StoppingCriterionType;

    OptimizationTrainer(
        LossType* loss,
        OptimizerType* optimizer,
        StoppingCriterionType* stoppingCriterion)
    : mep_loss(loss), mep_optimizer(optimizer), mep_stoppingCriterion(stoppingCriterion)
    {
        SHARK_RUNTIME_CHECK(loss != nullptr, "Loss function must not be NULL");
        SHARK_RUNTIME_CHECK(optimizer != nullptr, "Optimizer must not be NULL");
        SHARK_RUNTIME_CHECK(stoppingCriterion != nullptr, "Stopping criterion must not be NULL");
    }

    /// \brief From INameable: return the class name.
    std::string name() const
    {
        return "OptimizationTrainer<"
            + mep_loss->name() + ","
            + mep_optimizer->name() + ">";
    }

    void train(ModelType& model, LabeledData<InputType, LabelType> const& dataset) {
        // Set up the data-dependent objective and hand it to the optimizer.
        ErrorFunction error(dataset, &model, mep_loss);
        error.init();
        mep_optimizer->init(error);
        mep_stoppingCriterion->reset();
        // Iterate until the stopping criterion accepts the current solution.
        do {
            mep_optimizer->step(error);
        }
        while (!mep_stoppingCriterion->stop(mep_optimizer->solution()));
        // Write the best parameters found back into the model.
        model.setParameterVector(mep_optimizer->solution().point);
    }
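
    // The loop in train() always performs at least one optimizer step and then
    // consults the stopping criterion with the optimizer's current solution().
    // A minimal sketch of a custom criterion bounding the number of steps,
    // assuming only the AbstractStoppingCriterion interface (reset/stop); the
    // class name MaxSteps is hypothetical, and Shark's MaxIterations criterion
    // serves the same purpose:
    //
    //   class MaxSteps : public AbstractStoppingCriterion<SingleObjectiveResultSet<RealVector> > {
    //   public:
    //       explicit MaxSteps(unsigned int maxSteps) : m_max(maxSteps), m_count(0) {}
    //       void reset() { m_count = 0; }
    //       bool stop(SingleObjectiveResultSet<RealVector> const&) { return ++m_count >= m_max; }
    //   private:
    //       unsigned int m_max;
    //       unsigned int m_count;
    //   };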

    void read( InArchive & archive )
    {}

    void write( OutArchive & archive ) const
    {}

protected:
    LossType* mep_loss;
    OptimizerType* mep_optimizer;
    StoppingCriterionType* mep_stoppingCriterion;
};


}
#endif