RFTrainer.h
//===========================================================================
/*!
 *
 *
 * \brief Random Forest Trainer
 *
 *
 *
 * \author K. N. Hansen, J. Kremer
 * \date 2011-2012
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <http://shark-ml.org/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================


#ifndef SHARK_ALGORITHMS_TRAINERS_RFTRAINER_H
#define SHARK_ALGORITHMS_TRAINERS_RFTRAINER_H

#include <shark/Models/Trees/RFClassifier.h>
#include <shark/Algorithms/Trainers/AbstractWeightedTrainer.h>
#include <shark/Algorithms/Trainers/Impl/CART.h>

#include <cmath>
#include <vector>
#include <limits>

namespace shark {
/*!
 * \brief Random Forest
 *
 * Random Forest is an ensemble learner that builds multiple binary decision trees.
 * The trees are built using a variant of the CART methodology.
 *
 * Typically 100+ trees are built, and classification/regression is done by combining
 * the results generated by each tree. Typically a majority vote is used in the
 * classification case, and the mean in the regression case.
 *
 * Each tree is built based on a random subset of the total dataset. Furthermore,
 * at each split only a random subset of the attributes is investigated for
 * the best split.
 *
 * The node impurity is measured by the Gini criterion in the classification
 * case, and by the total sum of squared errors in the regression case.
 *
 * After growing a maximum sized tree, the tree is added to the ensemble
 * without pruning.
 *
 * For detailed information about Random Forest, see "Random Forests",
 * L. Breiman, Machine Learning 45(1), 2001.
 */
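/*!
 * Example: a minimal usage sketch for the classification case (an
 * illustration added for this documentation page, not part of the original
 * header). It assumes a ClassificationDataset named `data` has already been
 * loaded, e.g. via importCSV:
 * \code
 * RFTrainer<unsigned int> trainer(true, true); // also compute feature importances and OOB error
 * trainer.setNTrees(200);                      // grow 200 trees instead of the default 100
 * RFClassifier<unsigned int> forest;
 * trainer.train(forest, data);                 // unweighted overload inherited from AbstractWeightedTrainer
 * \endcode
 */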

template<class LabelType>
class RFTrainer;

template<>
class RFTrainer<unsigned int>
: public AbstractWeightedTrainer<RFClassifier<unsigned int> >, public IParameterizable<RealVector>
{
public:
    /// Constructor. Optionally enables computing the feature importances
    /// and/or the out-of-bag (OOB) error during training.
    RFTrainer(bool computeFeatureImportances = false, bool computeOOBerror = false){
        m_computeFeatureImportances = computeFeatureImportances;
        m_computeOOBerror = computeOOBerror;
        m_numTrees = 100;
        m_min_samples_leaf = 1;
        m_min_split = 2 * m_min_samples_leaf;
        m_max_depth = 10000;
        m_min_impurity_split = 1e-10;
        m_epsilon = 1e-10;
        m_max_features = 0;
    }

    /// \brief From INameable: return the class name.
    std::string name() const
    { return "RFTrainer"; }

    /// Set the number of random attributes to investigate at each node.
    ///
    /// Default is 0, which is translated to sqrt(inputDimension(data)) during training.
    void setMTry(std::size_t mtry) { m_max_features = mtry; }

    /// Set the number of trees to grow. (default 100)
    void setNTrees(std::size_t numTrees) {m_numTrees = numTrees;}

    /// Set the minimum number of samples a node must contain to be split. (default 2)
    void setMinSplit(std::size_t numSamples) {m_min_split = numSamples;}

    /// Set the maximum depth of the tree. (default 10000)
    void setMaxDepth(std::size_t maxDepth) {m_max_depth = maxDepth;}

    /// Controls when a node is considered pure. If set to 1, a node is pure
    /// when it only consists of a single sample. (default 1)
    void setNodeSize(std::size_t nodeSize) { m_min_samples_leaf = nodeSize; }

    /// The minimum impurity below which a node is considered pure. (default 1.e-10)
    void minImpurity(double impurity) {m_min_impurity_split = impurity;}

    /// The minimum distance between two feature values for them to be
    /// considered different. (default 1.e-10)
    void epsilon(double distance) {m_epsilon = distance;}

    /// Return the parameter vector.
    RealVector parameterVector() const{return RealVector();}

    /// Set the parameter vector.
    void setParameterVector(RealVector const& newParameters){
        SHARK_ASSERT(newParameters.size() == 0);
    }


    /// Train a random forest for classification.
    void train(RFClassifier<unsigned int>& model, WeightedLabeledData<RealVector, unsigned int> const& dataset){
        model.clearModels();
        model.setOutputSize(numberOfClasses(dataset));

        //setup treebuilder
        CART::TreeBuilder<unsigned int,CART::ClassificationCriterion> builder;
        builder.m_min_samples_leaf = m_min_samples_leaf;
        builder.m_min_split = m_min_split;
        builder.m_max_depth = m_max_depth;
        builder.m_min_impurity_split = m_min_impurity_split;
        builder.m_epsilon = m_epsilon;
        builder.m_max_features = m_max_features? m_max_features: std::sqrt(inputDimension(dataset));

        //copy data into a single batch for easier lookup
        blas::matrix<double, blas::column_major> data_train = createBatch<RealVector>(dataset.inputs().elements().begin(),dataset.inputs().elements().end());
        auto labels_train = createBatch<LabelType>(dataset.labels().elements().begin(),dataset.labels().elements().end());
        auto weights_train = createBatch<double>(dataset.weights().elements().begin(),dataset.weights().elements().end());

        //Set up one seed per tree, drawn sequentially from the global rng,
        //so results do not depend on the thread schedule
        std::vector<unsigned int> seeds(m_numTrees);
        for (auto& seed: seeds) {
            seed = random::discrete(random::globalRng, 0u,std::numeric_limits<unsigned int>::max());
        }

        //out-of-bag indices of each tree, used for OOB error and feature importances
        std::vector<std::vector<std::size_t> > complements;

        //Generate trees
        SHARK_PARALLEL_FOR(int t = 0; t < m_numTrees; ++t){
            random::rng_type rng(seeds[t]);

            //Setup data for this tree
            CART::Bootstrap<blas::matrix<double, blas::column_major>, UIntVector> bootstrap(rng, data_train,labels_train, weights_train);
            auto const& tree = builder.buildTree(rng, bootstrap);

            SHARK_CRITICAL_REGION{
                model.addModel(tree);
                complements.push_back(std::move(bootstrap.complement));
            }
        }

        if(m_computeOOBerror)
            model.computeOOBerror(complements, dataset.data());

        if(m_computeFeatureImportances)
            model.computeFeatureImportances(complements, dataset.data(), random::globalRng);
    }
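
    // Note (editorial illustration, not in the original header): when the
    // constructor flags are set, train() fills in the OOB error and the
    // feature importances on the model. Assuming RFClassifier exposes them
    // via accessors named OOBerror() and featureImportances(), they can be
    // read after training as:
    //   double oob = forest.OOBerror();                       // requires computeOOBerror = true
    //   RealVector const& imp = forest.featureImportances();  // requires computeFeatureImportances = true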


private:
    bool m_computeFeatureImportances; ///< set true if the feature importances should be computed
    bool m_computeOOBerror; ///< set true if the OOB error should be computed

    long m_numTrees; ///< number of trees in the forest
    std::size_t m_max_features; ///< number of attributes to randomly test at each inner node
    std::size_t m_min_samples_leaf; ///< minimum number of samples in a leaf node
    std::size_t m_min_split; ///< minimum number of samples a node must contain to be split
    std::size_t m_max_depth; ///< maximum depth of the tree
    double m_epsilon; ///< minimum difference between two values to be considered different
    double m_min_impurity_split; ///< stops splitting when the impurity is below this threshold
};


template<>
class RFTrainer<RealVector>
: public AbstractWeightedTrainer<RFClassifier<RealVector> >, public IParameterizable<RealVector>
{
public:
    /// Constructor. Optionally enables computing the feature importances
    /// and/or the out-of-bag (OOB) error during training.
    RFTrainer(bool computeFeatureImportances = false, bool computeOOBerror = false){
        m_computeFeatureImportances = computeFeatureImportances;
        m_computeOOBerror = computeOOBerror;
        m_numTrees = 100;
        m_min_samples_leaf = 5;
        m_min_split = 2 * m_min_samples_leaf;
        m_max_depth = 10000;
        m_min_impurity_split = 1e-10;
        m_epsilon = 1e-10;
        m_max_features = 0;
    }

    /// \brief From INameable: return the class name.
    std::string name() const
    { return "RFTrainer"; }

    /// Set the number of random attributes to investigate at each node.
    ///
    /// Default is 0, which is translated to inputDimension(data)/3 during training.
    void setMTry(std::size_t mtry) { m_max_features = mtry; }

    /// Set the number of trees to grow. (default 100)
    void setNTrees(std::size_t numTrees) {m_numTrees = numTrees;}

    /// Set the minimum number of samples a node must contain to be split. (default 10)
    void setMinSplit(std::size_t numSamples) {m_min_split = numSamples;}

    /// Set the maximum depth of the tree. (default 10000)
    void setMaxDepth(std::size_t maxDepth) {m_max_depth = maxDepth;}

    /// Controls when a node is considered pure. If set to 1, a node is pure
    /// when it only consists of a single sample. (default 5)
    void setNodeSize(std::size_t nodeSize) { m_min_samples_leaf = nodeSize; }

    /// The minimum impurity below which a node is considered pure. (default 1.e-10)
    void minImpurity(double impurity) {m_min_impurity_split = impurity;}

    /// The minimum distance between two feature values for them to be
    /// considered different. (default 1.e-10)
    void epsilon(double distance) {m_epsilon = distance;}

    /// Return the parameter vector.
    RealVector parameterVector() const{ return RealVector();}

    /// Set the parameter vector.
    void setParameterVector(RealVector const& newParameters){
        SHARK_ASSERT(newParameters.size() == 0);
    }


    /// Train a random forest for regression.
    void train(RFClassifier<RealVector>& model, WeightedLabeledData<RealVector, RealVector> const& dataset){
        model.clearModels();
        model.setOutputSize(labelDimension(dataset));

        //setup treebuilder
        CART::TreeBuilder<RealVector,CART::MSECriterion> builder;
        builder.m_min_samples_leaf = m_min_samples_leaf;
        builder.m_min_split = m_min_split;
        builder.m_max_depth = m_max_depth;
        builder.m_min_impurity_split = m_min_impurity_split;
        builder.m_epsilon = m_epsilon;
        builder.m_max_features = m_max_features? m_max_features: inputDimension(dataset)/3;

        //copy data into a single batch for easier lookup
        blas::matrix<double, blas::column_major> data_train = createBatch<RealVector>(dataset.inputs().elements().begin(),dataset.inputs().elements().end());
        auto labels_train = createBatch<LabelType>(dataset.labels().elements().begin(),dataset.labels().elements().end());
        auto weights_train = createBatch<double>(dataset.weights().elements().begin(),dataset.weights().elements().end());

        //Set up one seed per tree, drawn sequentially from the global rng,
        //so results do not depend on the thread schedule
        std::vector<unsigned int> seeds(m_numTrees);
        for (auto& seed: seeds) {
            seed = random::discrete(random::globalRng, 0u,std::numeric_limits<unsigned int>::max());
        }

        //out-of-bag indices of each tree, used for OOB error and feature importances
        std::vector<std::vector<std::size_t> > complements;

        //Generate trees
        SHARK_PARALLEL_FOR(int t = 0; t < m_numTrees; ++t){
            random::rng_type rng{seeds[t]};

            //Setup data for this tree
            CART::Bootstrap<blas::matrix<double, blas::column_major>, RealMatrix> bootstrap(rng, data_train,labels_train, weights_train);
            auto const& tree = builder.buildTree(rng, bootstrap);

            SHARK_CRITICAL_REGION{
                model.addModel(tree);
                complements.push_back(std::move(bootstrap.complement));
            }
        }

        if(m_computeOOBerror)
            model.computeOOBerror(complements, dataset.data());

        if(m_computeFeatureImportances)
            model.computeFeatureImportances(complements, dataset.data(), random::globalRng);
    }


private:
    bool m_computeFeatureImportances; ///< set true if the feature importances should be computed
    bool m_computeOOBerror; ///< set true if the OOB error should be computed

    long m_numTrees; ///< number of trees in the forest
    std::size_t m_max_features; ///< number of attributes to randomly test at each inner node
    std::size_t m_min_samples_leaf; ///< minimum number of samples in a leaf node
    std::size_t m_min_split; ///< minimum number of samples a node must contain to be split
    std::size_t m_max_depth; ///< maximum depth of the tree
    double m_epsilon; ///< minimum difference between two values to be considered different
    double m_min_impurity_split; ///< stops splitting when the impurity is below this threshold
};
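
/*!
 * Example: a minimal usage sketch for the regression case (an illustration
 * added for this documentation page, not part of the original header). It
 * assumes a RegressionDataset named `data` is already loaded:
 * \code
 * RFTrainer<RealVector> trainer;
 * RFClassifier<RealVector> forest;
 * trainer.train(forest, data);
 * Data<RealVector> predictions = forest(data.inputs()); // mean over the trees' outputs
 * \endcode
 */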


}
#endif