36 #ifndef SHARK_ALGORITHMS_TRAINERS_RFTRAINER_H 37 #define SHARK_ALGORITHMS_TRAINERS_RFTRAINER_H 41 #include <shark/Algorithms/Trainers/Impl/CART.h> 71 template<
class LabelType>
/// Constructor of the trainer (only part of its body is visible in this excerpt).
///
/// \param computeFeatureImportances stored in m_computeFeatureImportances and
///        checked after the trees have been built.
/// \param computeOOBerror stored in m_computeOOBerror and checked after the
///        trees have been built.
///
/// Visible defaults: minimum leaf size 1, minimum split size twice the leaf
/// size (i.e. 2), minimum impurity for a split 1e-10.
80 RFTrainer(
bool computeFeatureImportances =
false,
bool computeOOBerror =
false){
81 m_computeFeatureImportances = computeFeatureImportances;
82 m_computeOOBerror = computeOOBerror;
84 m_min_samples_leaf = 1;
// Derived from the leaf size set on the previous line; later calls to
// setNodeSize() do not re-derive it.
85 m_min_split = 2 * m_min_samples_leaf;
87 m_min_impurity_split = 1e-10;
// Body of a member whose signature is not visible in this excerpt (presumably
// the trainer's name accessor — confirm). It returns the literal "RFTrainer".
94 {
return "RFTrainer"; }
99 void setMTry(std::size_t mtry) { m_max_features = mtry; }
102 void setNTrees(std::size_t numTrees) {m_numTrees = numTrees;}
105 void setMinSplit(std::size_t numSamples) {m_min_split = numSamples;}
112 void setNodeSize(std::size_t nodeSize) { m_min_samples_leaf = nodeSize; }
115 void minImpurity(
double impurity) {m_min_impurity_split = impurity;}
118 void epsilon(
double distance) {m_epsilon = distance;}
136 CART::TreeBuilder<unsigned int,CART::ClassificationCriterion> builder;
137 builder.m_min_samples_leaf = m_min_samples_leaf;
138 builder.m_min_split = m_min_split;
139 builder.m_max_depth = m_max_depth;
140 builder.m_min_impurity_split = m_min_impurity_split;
141 builder.m_epsilon = m_epsilon;
142 builder.m_max_features = m_max_features? m_max_features: std::sqrt(
inputDimension(dataset));
// NOTE: this is an excerpt of the training routine; several statements
// (e.g. the declaration of labels_train and of the tree index t) are not visible.
// Batch all input elements of the dataset into one column-major matrix.
145 blas::matrix<double, blas::column_major> data_train = createBatch<RealVector>(dataset.
inputs().
elements().begin(),dataset.
inputs().
elements().end());
// Batch the per-element weights the same way.
147 auto weights_train = createBatch<double>(dataset.weights().elements().begin(),dataset.weights().elements().end());
// One seed slot per tree (m_numTrees entries); the loop visits every slot by
// reference, its body is not visible here.
150 std::vector<unsigned int> seeds(m_numTrees);
151 for (
auto& seed: seeds) {
// Receives one index list per tree, taken from bootstrap.complement below.
155 std::vector<std::vector<std::size_t> > complements;
// Generator for tree t, constructed from that tree's own seed.
159 random::rng_type rng(seeds[t]);
// Bootstrap object built from the batched inputs, labels and weights.
162 CART::Bootstrap<blas::matrix<double, blas::column_major>, UIntVector>
bootstrap(rng, data_train,labels_train, weights_train);
// Build one tree on this bootstrap with the configured builder.
163 auto const& tree = builder.buildTree(rng, bootstrap);
// Move (not copy) the bootstrap's complement into the per-tree list.
167 complements.push_back(std::move(bootstrap.complement));
// The two checks below use the flags stored by the constructor; the statements
// they guard are not visible in this excerpt.
171 if(m_computeOOBerror)
174 if(m_computeFeatureImportances)
// Trainer state (excerpt; further members such as m_numTrees and m_epsilon are
// used by the code above but their declarations are not visible here).
/// Constructor flag, checked after the trees have been built.
180 bool m_computeFeatureImportances;
/// Constructor flag, checked after the trees have been built.
181 bool m_computeOOBerror;
/// Set by setMTry(); 0 makes training pick a default from the input dimension.
184 std::size_t m_max_features;
/// Set by setNodeSize(); the constructor initialises it to 1.
185 std::size_t m_min_samples_leaf;
/// Set by setMinSplit(); the constructor initialises it to 2 * m_min_samples_leaf.
186 std::size_t m_min_split;
/// Forwarded unchanged to the tree builder's m_max_depth.
187 std::size_t m_max_depth;
/// Set by minImpurity(); the constructor initialises it to 1e-10.
189 double m_min_impurity_split;
/// Constructor of the trainer (only part of its body is visible in this excerpt).
///
/// \param computeFeatureImportances stored in m_computeFeatureImportances and
///        checked after the trees have been built.
/// \param computeOOBerror stored in m_computeOOBerror and checked after the
///        trees have been built.
///
/// Visible defaults: minimum leaf size 1, minimum split size twice the leaf
/// size (i.e. 2), minimum impurity for a split 1e-10.
199 RFTrainer(
bool computeFeatureImportances =
false,
bool computeOOBerror =
false){
200 m_computeFeatureImportances = computeFeatureImportances;
201 m_computeOOBerror = computeOOBerror;
203 m_min_samples_leaf = 1;
// Derived from the leaf size set on the previous line; later calls to
// setNodeSize() do not re-derive it.
204 m_min_split = 2 * m_min_samples_leaf;
206 m_min_impurity_split = 1e-10;
// Body of a member whose signature is not visible in this excerpt (presumably
// the trainer's name accessor — confirm). It returns the literal "RFTrainer".
213 {
return "RFTrainer"; }
218 void setMTry(std::size_t mtry) { m_max_features = mtry; }
221 void setNTrees(std::size_t numTrees) {m_numTrees = numTrees;}
224 void setMinSplit(std::size_t numSamples) {m_min_split = numSamples;}
231 void setNodeSize(std::size_t nodeSize) { m_min_samples_leaf = nodeSize; }
234 void minImpurity(
double impurity) {m_min_impurity_split = impurity;}
237 void epsilon(
double distance) {m_epsilon = distance;}
253 CART::TreeBuilder<RealVector,CART::MSECriterion> builder;
254 builder.m_min_samples_leaf = m_min_samples_leaf;
255 builder.m_min_split = m_min_split;
256 builder.m_max_depth = m_max_depth;
257 builder.m_min_impurity_split = m_min_impurity_split;
258 builder.m_epsilon = m_epsilon;
259 builder.m_max_features = m_max_features? m_max_features:
inputDimension(dataset)/3;
// NOTE: this is an excerpt of the training routine; several statements
// (e.g. the declaration of labels_train and of the tree index t) are not visible.
// Batch all input elements of the dataset into one column-major matrix.
261 blas::matrix<double, blas::column_major> data_train = createBatch<RealVector>(dataset.
inputs().
elements().begin(),dataset.
inputs().
elements().end());
// Batch the per-element weights the same way.
263 auto weights_train = createBatch<double>(dataset.weights().elements().begin(),dataset.weights().elements().end());
// One seed slot per tree (m_numTrees entries); the loop visits every slot by
// reference, its body is not visible here.
266 std::vector<unsigned int> seeds(m_numTrees);
267 for (
auto& seed: seeds) {
// Receives one index list per tree, taken from bootstrap.complement below.
271 std::vector<std::vector<std::size_t> > complements;
// Generator for tree t, constructed from that tree's own seed.
275 random::rng_type rng{seeds[t]};
// Bootstrap object built from the batched inputs, labels and weights.
278 CART::Bootstrap<blas::matrix<double, blas::column_major>, RealMatrix>
bootstrap(rng, data_train,labels_train, weights_train);
// Build one tree on this bootstrap with the configured builder.
279 auto const& tree = builder.buildTree(rng, bootstrap);
// Move (not copy) the bootstrap's complement into the per-tree list.
283 complements.push_back(std::move(bootstrap.complement));
// The two checks below use the flags stored by the constructor; the statements
// they guard are not visible in this excerpt.
287 if(m_computeOOBerror)
290 if(m_computeFeatureImportances)
// Trainer state (excerpt; further members such as m_numTrees and m_epsilon are
// used by the code above but their declarations are not visible here).
/// Constructor flag, checked after the trees have been built.
296 bool m_computeFeatureImportances;
/// Constructor flag, checked after the trees have been built.
297 bool m_computeOOBerror;
/// Set by setMTry(); 0 makes training pick a default from the input dimension.
300 std::size_t m_max_features;
/// Set by setNodeSize(); the constructor initialises it to 1.
301 std::size_t m_min_samples_leaf;
/// Set by setMinSplit(); the constructor initialises it to 2 * m_min_samples_leaf.
302 std::size_t m_min_split;
/// Forwarded unchanged to the tree builder's m_max_depth.
303 std::size_t m_max_depth;
/// Set by minImpurity(); the constructor initialises it to 1e-10.
305 double m_min_impurity_split;