CARTTutorial.cpp
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief CART Tutorial Sample Code
6  *
7  * This file is part of the "CART" tutorial.
8  * It requires some toy sample data that comes with the library.
9  *
10  *
11  *
12  * \author K. N. Hansen
13  * \date 2012
14  *
15  *
16  * \par Copyright 1995-2017 Shark Development Team
17  *
18  * <BR><HR>
19  * This file is part of Shark.
20  * <http://shark-ml.org/>
21  *
22  * Shark is free software: you can redistribute it and/or modify
23  * it under the terms of the GNU Lesser General Public License as published
24  * by the Free Software Foundation, either version 3 of the License, or
25  * (at your option) any later version.
26  *
27  * Shark is distributed in the hope that it will be useful,
28  * but WITHOUT ANY WARRANTY; without even the implied warranty of
29  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
30  * GNU Lesser General Public License for more details.
31  *
32  * You should have received a copy of the GNU Lesser General Public License
33  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
34  *
35  */
36 //===========================================================================
37 
38 #include <shark/Data/Csv.h> // importing CSV files
39 #include <shark/Algorithms/Trainers/CARTTrainer.h> // the CART trainer
40 #include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h> // 0/1 loss for evaluation
41 #include <iostream>
42 
43 using namespace std;
44 using namespace shark;
45 
46 
47 int main() {
48 
49  //*****************LOAD AND PREPARE DATA***********************//
50 
51  // read data
52  ClassificationDataset dataTrain;
53  importCSV(dataTrain, "data/C.csv", LAST_COLUMN, ' ');
54 
55 
56  //Split the dataset into a training and a test dataset
57  ClassificationDataset dataTest =splitAtElement(dataTrain,311);
58 
59  cout << "Training set - number of data points: " << dataTrain.numberOfElements()
60  << " number of classes: " << numberOfClasses(dataTrain)
61  << " input dimension: " << inputDimension(dataTrain) << endl;
62 
63  cout << "Test set - number of data points: " << dataTest.numberOfElements()
64  << " number of classes: " << numberOfClasses(dataTest)
65  << " input dimension: " << inputDimension(dataTest) << endl;
66 
67 
68  //Train the model
69  CARTTrainer trainer;
70  CARTClassifier<RealVector> model;
71  trainer.train(model, dataTrain);
72 
73  // evaluate Random Forest classifier
75  Data<RealVector> prediction = model(dataTrain.inputs());
76  cout << "CART on training set accuracy: " << 1. - loss.eval(dataTrain.labels(), prediction) << endl;
77 
78  prediction = model(dataTest.inputs());
79  cout << "CART on test set accuracy: " << 1. - loss.eval(dataTest.labels(), prediction) << endl;
80 
81 }