// RFClassifierBase<LabelType>: generic base of RFClassifier. It derives from
// MeanModel<CARTree<LabelType> >, so the forest is a MeanModel whose members
// are CARTree models (how MeanModel combines them is defined elsewhere).
// NOTE(review): this listing is a fragmentary extract. Source line numbers are
// fused into the text, and many lines (local declarations, loop headers,
// closing braces, return statements) are missing, so it does not compile as
// shown. The comments in this file describe only what the visible lines do.
35 #ifndef SHARK_MODELS_TREES_RFCLASSIFIER_H 36 #define SHARK_MODELS_TREES_RFCLASSIFIER_H 49 template<
class LabelType>
50 class RFClassifierBase :
public MeanModel<CARTree<LabelType> >{
// Out-of-bag error for the generic (regression-style) forest.
// oobPoints(m, elem) selects which models contribute for data element `elem`
// (rows index models, columns index elements of `data`; non-zero = used).
// For each element, the predictions of the selected models are summed into
// `mean`, and 0.5 * ||label - mean||^2 is accumulated into OOBerror. The total
// is finally divided by the number of elements.
// NOTE(review): several lines are not visible here:
//   - the declarations of `input`, `mean`, `OOBerror` and `elem`;
//   - the loop over `m`;
//   - the increments of `oobModels` and `elem`;
//   - the return statement.
// Presumably `mean` is reset per element and divided by `oobModels` before the
// error term; TODO confirm. Also check that an element which is in-bag for
// every model (oobModels == 0) cannot cause a division by zero.
52 double doComputeOOBerror(
53 UIntMatrix
const& oobPoints, LabeledData<RealVector, RealVector>
const& data
60 for(
auto const& point: data.elements()){
61 noalias(input) = point.input;
// Number of models that contribute to this element's prediction.
63 std::size_t oobModels = 0;
65 if(oobPoints(m,elem)){
67 auto const& model = this->
getModel(m);
68 noalias(
mean) += model(input);
// Half squared Euclidean distance between target and combined prediction.
72 OOBerror += 0.5 * norm_sqr(point.label -
mean);
75 OOBerror /= data.numberOfElements();
// Batch loss used to score a single tree's predictions (called as
// this->loss(batch.label, getModel(m)(batch.input)) by the feature-importance
// computation). It returns the SquaredLoss between the label and prediction
// matrices.
// NOTE(review): the local `loss` shadows this member function's name.
// Consider renaming it (e.g. `squaredLoss`) for readability.
79 double loss(RealMatrix
const& labels, RealMatrix
const& predictions)
const{
80 SquaredLoss<RealVector, RealVector> loss;
81 return loss.eval(labels, predictions);
// Specialization of RFClassifierBase for unsigned int (class) labels. The
// forest is wrapped in Classifier<MeanModel<CARTree<unsigned int> > >, so the
// methods below reach the trees through decisionFunction().
// NOTE(review): an explicit specialization needs a preceding `template<>`.
// That line is not visible in this listing; confirm it exists.
86 class RFClassifierBase<unsigned int> :
public Classifier<MeanModel<CARTree<unsigned int> > >{
// Returns the tree at position `index` of the wrapped ensemble
// (forwards to decisionFunction().getModel).
89 CARTree<unsigned int>
const& getModel(std::size_t index)
const{
90 return this->decisionFunction().getModel(index);
// Adds `model` to the wrapped ensemble with the given weight (default 1.0).
// Forwards to decisionFunction().addModel.
93 void addModel(CARTree<unsigned int>
const& model,
double weight = 1.0){
94 this->decisionFunction().addModel(model,weight);
// Body of the model-clearing method (its signature line is not visible in this
// listing). It forwards to decisionFunction().clearModels().
97 this->decisionFunction().clearModels();
// Sets the output size of the wrapped ensemble to `dim`
// (forwards to decisionFunction().setOutputSize).
100 void setOutputSize(std::size_t dim){
101 this->decisionFunction().setOutputSize(dim);
// Number of trees in the wrapped ensemble
// (forwards to decisionFunction().numberOfModels()).
105 std::size_t numberOfModels()
const{
106 return this->decisionFunction().numberOfModels();
// Classification version of the per-tree loss: ZeroOneLoss between the true
// labels and the predicted labels.
// NOTE(review): the local `loss` shadows this member function's name.
109 double loss(UIntVector
const& labels, UIntVector
const& predictions)
const{
110 ZeroOneLoss<unsigned int> loss;
111 return loss.eval(labels, predictions);
// Out-of-bag classification error. For each element of `data`, every model m
// with oobPoints(m, elem) non-zero predicts a label. The element counts as an
// error when arg_max(votes) differs from its true label. The error count is
// divided by the number of elements, giving the fraction of misclassified
// elements (assuming OOBerror starts at 0).
// NOTE(review): several lines are not visible here:
//   - the declarations of `input`, `votes` and `OOBerror`;
//   - the vote increment (presumably votes(label) += 1);
//   - the per-element reset of `votes`;
//   - the `++elem`;
//   - the return statement.
// TODO confirm. Also confirm how an element with no out-of-bag model
// (all-zero `votes`) is meant to be counted.
114 double doComputeOOBerror(
115 UIntMatrix
const& oobPoints, LabeledData<RealVector, unsigned int>
const& data
121 std::size_t elem = 0;
122 for(
auto const& point: data.elements()){
123 noalias(input) = point.input;
125 for(std::size_t m = 0; m != numberOfModels();++m){
126 if(oobPoints(m,elem)){
127 auto const& model = getModel(m);
128 unsigned int label = model(input);
// Misclassified when the majority vote differs from the true label.
132 OOBerror += (arg_max(votes) != point.label);
135 OOBerror /= data.numberOfElements();
// RFClassifier<LabelType> (its class-head line is not visible in this listing).
// The visible fragment returns the string "RFClassifier", presumably the body
// of name(); TODO confirm.
153 template<
class LabelType>
159 {
return "RFClassifier"; }
// Returns the stored per-input feature importances. The enclosing accessor's
// signature is not visible in this listing.
169 return m_featureImportances;
// Sums countAttributes() over all trees and returns the element-wise total.
// Returns an empty UIntVector when the forest has no trees.
// Presumably entry i counts how often attribute i is used by the trees; the
// meaning of CARTree::countAttributes() is defined elsewhere, so confirm.
// NOTE(review): the final `return r;` is not visible in this listing.
174 std::size_t n = this->numberOfModels();
175 if(!n)
return UIntVector();
176 UIntVector r = this->getModel(0).countAttributes();
177 for(std::size_t i=1; i< n; i++ ) {
178 noalias(r) += this->getModel(i).countAttributes();
// Builds the model-by-element out-of-bag indicator matrix: row i gets a 1 at
// every data index listed in oobIndices[i]. The resulting out-of-bag error is
// then stored in m_OOBerror via doComputeOOBerror, which reads the matrix as
// oobPoints(model, element).
// NOTE(review): the declaration of oobMatrix is not visible. It must be
// zero-initialised, since only the 1 entries are written here; TODO confirm.
186 for(std::size_t i = 0; i != oobMatrix.size1(); ++i){
187 for(
auto index: oobIndices[i])
188 oobMatrix(i,index) = 1;
190 m_OOBerror = this->doComputeOOBerror(oobMatrix,data);
// Feature importances. For each tree m:
//   1. Measure the loss on its out-of-bag batch subBatch(view, oobIndices[m]).
//   2. For each input column i, replace the column with `v`, measure the loss
//      again, then restore the column.
//   3. Add the increase (errorAfter - errorBefore) / batch.size() to
//      m_featureImportances(i).
// The sums are finally averaged over the number of trees.
// NOTE(review): the lines computing `v` (presumably a shuffled copy of vOld)
// and declaring `view`/`inputs` are not visible; TODO confirm. Also confirm
// that m_featureImportances is zeroed after resize, since it is accumulated
// with +=. Finally, check that a tree with an empty out-of-bag set
// (batch.size() == 0) cannot reach the division.
199 m_featureImportances.resize(inputs);
202 for(std::size_t m = 0; m != this->numberOfModels();++m){
203 auto batch =
subBatch(view, oobIndices[m]);
204 double errorBefore = this->loss(batch.label,this->getModel(m)(batch.input));
206 for(std::size_t i=0; i!=inputs;++i) {
207 RealVector vOld= column(batch.input,i);
210 noalias(column(batch.input,i)) = v;
211 double errorAfter = this->loss(batch.label,this->getModel(m)(batch.input));
// Restore the original column so the next feature is scored on unmodified data.
212 noalias(column(batch.input,i)) = vOld;
213 m_featureImportances(i) += (errorAfter - errorBefore) / batch.size();
216 m_featureImportances /= this->numberOfModels();
// Per-input importance scores: resized to the number of inputs and filled by
// the feature-importance loop, then returned by the importance accessor.
221 RealVector m_featureImportances;