RFClassifier.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Random Forest Classifier.
6  *
7  *
8  *
9  * \author K. N. Hansen, O.Krause, J. Kremer
10  * \date 2011-2012
11  *
12  *
13  * \par Copyright 1995-2017 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://shark-ml.org/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_MODELS_TREES_RFCLASSIFIER_H
36 #define SHARK_MODELS_TREES_RFCLASSIFIER_H
37 
39 #include <shark/Models/MeanModel.h>
43 #include <shark/Data/DataView.h>
44 
45 namespace shark {
46 
47 namespace detail{
48 //this class bridges the differences between random forests in classification and regression
49 template<class LabelType>
50 class RFClassifierBase : public MeanModel<CARTree<LabelType> >{
51 protected:
52  double doComputeOOBerror(
53  UIntMatrix const& oobPoints, LabeledData<RealVector, RealVector> const& data
54  ){
55  double OOBerror = 0;
56  //aquire votes for every element
57  RealVector mean(labelDimension(data));
58  RealVector input(inputDimension(data));
59  std::size_t elem = 0;
60  for(auto const& point: data.elements()){
61  noalias(input) = point.input;
62  mean.clear();
63  std::size_t oobModels = 0;
64  for(std::size_t m = 0; m != this->numberOfModels();++m){
65  if(oobPoints(m,elem)){
66  ++oobModels;
67  auto const& model = this->getModel(m);
68  noalias(mean) += model(input);
69  }
70  }
71  mean /= oobModels;
72  OOBerror += 0.5 * norm_sqr(point.label - mean);
73  ++elem;
74  }
75  OOBerror /= data.numberOfElements();
76  return OOBerror;
77  }
78 
79  double loss(RealMatrix const& labels, RealMatrix const& predictions) const{
80  SquaredLoss<RealVector, RealVector> loss;
81  return loss.eval(labels, predictions);
82  }
83 };
84 
85 template<>
86 class RFClassifierBase<unsigned int> : public Classifier<MeanModel<CARTree<unsigned int> > >{
87 public:
88  //make the interface of MeanModel publicly available for same basic interface for classification and regression case
89  CARTree<unsigned int> const& getModel(std::size_t index)const{
90  return this->decisionFunction().getModel(index);
91  }
92 
93  void addModel(CARTree<unsigned int> const& model, double weight = 1.0){
94  this->decisionFunction().addModel(model,weight);
95  }
96  void clearModels(){
97  this->decisionFunction().clearModels();
98  }
99 
100  void setOutputSize(std::size_t dim){
101  this->decisionFunction().setOutputSize(dim);
102  }
103 
104  /// \brief Returns the number of models.
105  std::size_t numberOfModels()const{
106  return this->decisionFunction().numberOfModels();
107  }
108 protected:
109  double loss(UIntVector const& labels, UIntVector const& predictions) const{
110  ZeroOneLoss<unsigned int> loss;
111  return loss.eval(labels, predictions);
112  }
113 
114  double doComputeOOBerror(
115  UIntMatrix const& oobPoints, LabeledData<RealVector, unsigned int> const& data
116  ){
117  double OOBerror = 0;
118  //aquire votes for every element
119  RealVector votes(numberOfClasses(data));
120  RealVector input(inputDimension(data));
121  std::size_t elem = 0;
122  for(auto const& point: data.elements()){
123  noalias(input) = point.input;
124  votes.clear();
125  for(std::size_t m = 0; m != numberOfModels();++m){
126  if(oobPoints(m,elem)){
127  auto const& model = getModel(m);
128  unsigned int label = model(input);
129  votes(label) += 1;
130  }
131  }
132  OOBerror += (arg_max(votes) != point.label);
133  ++elem;
134  }
135  OOBerror /= data.numberOfElements();
136  return OOBerror;
137  }
138 };
139 }
140 
141 ///
142 /// \brief Random Forest Classifier.
143 ///
144 /// \par
145 /// The Random Forest Classifier predicts a class label
146 /// using the Random Forest algorithm as described in<br/>
147 /// Random Forests. Leo Breiman. Machine Learning, 1(45), pages 5-32. Springer, 2001.<br/>
148 ///
149 /// \par
150 /// It is an ensemble learner that uses multiple decision trees built
151 /// using the CART methodology.
152 ///
153 template<class LabelType>
154 class RFClassifier : public detail::RFClassifierBase<LabelType>
155 {
156 public:
157  /// \brief From INameable: return the class name.
158  std::string name() const
159  { return "RFClassifier"; }
160 
161 
162  /// \brief Returns the computed out-of-bag-error of the forest
163  double OOBerror() const {
164  return m_OOBerror;
165  }
166 
167  /// \brief Returns the computed feature importances of the forest
168  RealVector const& featureImportances()const{
169  return m_featureImportances;
170  }
171 
172  /// \brief Counts how often attributes are used
173  UIntVector countAttributes() const {
174  std::size_t n = this->numberOfModels();
175  if(!n) return UIntVector();
176  UIntVector r = this->getModel(0).countAttributes();
177  for(std::size_t i=1; i< n; i++ ) {
178  noalias(r) += this->getModel(i).countAttributes();
179  }
180  return r;
181  }
182 
183  /// Compute oob error, given an oob dataset (Classification)
184  void computeOOBerror(std::vector<std::vector<std::size_t> > const& oobIndices, LabeledData<RealVector, LabelType> const& data){
185  UIntMatrix oobMatrix(oobIndices.size(), data.numberOfElements(),0);
186  for(std::size_t i = 0; i != oobMatrix.size1(); ++i){
187  for(auto index: oobIndices[i])
188  oobMatrix(i,index) = 1;
189  }
190  m_OOBerror = this->doComputeOOBerror(oobMatrix,data);
191  }
192 
193  /// Compute feature importances, given an oob dataset
194  ///
195  /// For each tree, extracts the out-of-bag-samples indicated by oobIndices. The feature importance is defined
196  /// as the average change of loss (Squared loss or accuracy depending on label type) when the feature is permuted across the oob samples of a tree.
197  void computeFeatureImportances(std::vector<std::vector<std::size_t> > const& oobIndices, LabeledData<RealVector, LabelType> const& data, random::rng_type& rng){
198  std::size_t inputs = inputDimension(data);
199  m_featureImportances.resize(inputs);
201 
202  for(std::size_t m = 0; m != this->numberOfModels();++m){
203  auto batch = subBatch(view, oobIndices[m]);
204  double errorBefore = this->loss(batch.label,this->getModel(m)(batch.input));
205 
206  for(std::size_t i=0; i!=inputs;++i) {
207  RealVector vOld= column(batch.input,i);
208  RealVector v = vOld;
209  std::shuffle(v.begin(), v.end(), rng);
210  noalias(column(batch.input,i)) = v;
211  double errorAfter = this->loss(batch.label,this->getModel(m)(batch.input));
212  noalias(column(batch.input,i)) = vOld;
213  m_featureImportances(i) += (errorAfter - errorBefore) / batch.size();
214  }
215  }
216  m_featureImportances /= this->numberOfModels();
217  }
218 
219 private:
220  double m_OOBerror; ///< oob error for the forest
221  RealVector m_featureImportances; ///< feature importances for the forest
222 
223 };
224 
225 
226 }
227 #endif