MeanModel.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Implements the Mean Model that can be used for ensemble classifiers
6  *
7  *
8  *
9  * \author Kang Li, O. Krause
10  * \date 2014
11  *
12  *
13  * \par Copyright 1995-2017 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://shark-ml.org/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_MODELS_MEANMODEL_H
36 #define SHARK_MODELS_MEANMODEL_H
37 
38 namespace shark {
39 /// \brief Calculates the weighted mean of a set of models
40 template<class ModelType>
41 class MeanModel : public AbstractModel<typename ModelType::InputType, RealVector, typename ModelType::ParameterVectorType>
42 {
43 private:
44  template<class T> struct tag{};
45 
46  template<class InputBatch>
47  void doEval(InputBatch const& patterns, RealMatrix& outputs, tag<RealVector>)const{
48  for(std::size_t i = 0; i != m_models.size(); i++)
49  noalias(outputs) += m_weight[i] * m_models[i](patterns);
50  outputs /= m_weightSum;
51  }
52  template<class InputBatch>
53  void doEval(InputBatch const& patterns, RealMatrix& outputs, tag<unsigned int>)const{
54  blas::vector<unsigned int> responses;
55  for(std::size_t i = 0; i != m_models.size(); ++i){
56  m_models[i].eval(patterns, responses);
57  for(std::size_t p = 0; p != patterns.size1(); ++p){
58  SIZE_CHECK(responses(p) < m_outputDim);
59  outputs(p,responses(p)) += m_weight[i];
60  }
61  }
62  outputs /= m_weightSum;
63  }
65 public:
66 
67 
71  /// Constructor
73 
74  std::string name() const
75  { return "MeanModel"; }
76 
77  ///\brief Returns the expected shape of the input
78  Shape inputShape() const{
79  return m_models.empty() ? Shape(): m_models.front().inputShape();
80  }
81  ///\brief Returns the shape of the output
82  Shape outputShape() const{
83  return m_models.empty() ? Shape(): m_models.front().outputShape();
84  }
85 
86  using ModelBaseType::eval;
87  void eval(BatchInputType const& patterns, BatchOutputType& outputs)const{
88  outputs.resize(patterns.size1(), m_outputDim);
89  outputs.clear();
90  doEval(patterns,outputs, tag<typename ModelType::OutputType>());
91  }
92 
93  void eval(BatchInputType const& patterns, BatchOutputType& outputs, State& state)const{
94  eval(patterns,outputs);
95  }
96 
97  std::size_t outputSize() const{
98  return m_outputDim;
99  }
100 
101 
102  /// This model does not have any parameters.
103  ParameterVectorType parameterVector() const {
104  return {};
105  }
106 
107  /// This model does not have any parameters
108  void setParameterVector(ParameterVectorType const& param) {
109  SHARK_ASSERT(param.size() == 0);
110  }
111  void read(InArchive& archive){
112  archive >> m_models;
113  archive >> m_weight;
114  archive >> m_weightSum;
115  archive >> m_outputDim;
116  }
117  void write(OutArchive& archive)const{
118  archive << m_models;
119  archive << m_weight;
120  archive << m_weightSum;
121  archive << m_outputDim;
122  }
123 
124  /// \brief Removes all models from the ensemble
125  void clearModels(){
126  m_models.clear();
127  m_weight.clear();
128  m_weightSum = 0.0;
129  }
130 
131  /// \brief Adds a new model to the ensemble.
132  ///
133  /// \param model the new model
134  /// \param weight weight of the model. must be > 0
135  void addModel(ModelType const& model, double weight = 1.0){
136  SHARK_RUNTIME_CHECK(weight > 0, "Weights must be positive");
137  m_models.push_back(model);
138  m_weight.push_back(weight);
139  m_weightSum += weight;
140  }
141 
142  ModelType const& getModel(std::size_t index)const{
143  return m_models[index];
144  }
145 
146  /// \brief Returns the weight of the i-th model
147  double const& weight(std::size_t i)const{
148  return m_weight[i];
149  }
150 
151  /// \brief sets the weight of the i-th model
152  void setWeight(std::size_t i, double newWeight){
153  m_weightSum += newWeight - m_weight[i];
154  m_weight[i] = newWeight;
155  }
156 
157  ///\brief sets the dimensionality of the output
158  void setOutputSize(std::size_t dim){
159  m_outputDim = dim;
160  }
161 
162  /// \brief Returns the number of models.
163  std::size_t numberOfModels()const{
164  return m_models.size();
165  }
166 
167 protected:
168  /// \brief collection of models.
169  std::vector<ModelType> m_models;
170 
171  /// \brief Weight of the mean.
172  RealVector m_weight;
173 
174  /// \brief Total sum of weights.
175  double m_weightSum;
176 
177  ///\brief output dimensionality
178  std::size_t m_outputDim;
179 };
180 
181 
182 }
183 #endif // SHARK_MODELS_MEANMODEL_H