Shark machine learning library
About Shark
News!
Contribute
Credits and copyright
Downloads
Getting Started
Installation
Using the docs
Documentation
Tutorials
Quick references
Class list
Global functions
FAQ
Showroom
examples
Supervised
CARTTutorial.cpp
Go to the documentation of this file.
1
//===========================================================================
2
/*!
3
*
4
*
5
* \brief CART Tutorial Sample Code
6
*
7
* This file is part of the "CART" tutorial.
8
* It requires some toy sample data that comes with the library.
9
*
10
*
11
*
12
* \author K. N. Hansen
13
* \date 2012
14
*
15
*
16
* \par Copyright 1995-2017 Shark Development Team
17
*
18
* <BR><HR>
19
* This file is part of Shark.
20
* <http://shark-ml.org/>
21
*
22
* Shark is free software: you can redistribute it and/or modify
23
* it under the terms of the GNU Lesser General Public License as published
24
* by the Free Software Foundation, either version 3 of the License, or
25
* (at your option) any later version.
26
*
27
* Shark is distributed in the hope that it will be useful,
28
* but WITHOUT ANY WARRANTY; without even the implied warranty of
29
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
30
* GNU Lesser General Public License for more details.
31
*
32
* You should have received a copy of the GNU Lesser General Public License
33
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
34
*
35
*/
36
//===========================================================================
37
38
#include <
shark/Data/Csv.h
>
// importing CSV files
39
#include <shark/Algorithms/Trainers/CARTTrainer.h>
// the CART trainer
40
#include <
shark/ObjectiveFunctions/Loss/ZeroOneLoss.h
>
// 0/1 loss for evaluation
41
#include <iostream>
42
43
using namespace
std
;
44
using namespace
shark
;
45
46
47
int
main
() {
48
49
//*****************LOAD AND PREPARE DATA***********************//
50
51
// read data
52
ClassificationDataset
dataTrain;
53
importCSV
(dataTrain,
"data/C.csv"
,
LAST_COLUMN
,
' '
);
54
55
56
//Split the dataset into a training and a test dataset
57
ClassificationDataset
dataTest =
splitAtElement
(dataTrain,311);
58
59
cout <<
"Training set - number of data points: "
<< dataTrain.
numberOfElements
()
60
<<
" number of classes: "
<<
numberOfClasses
(dataTrain)
61
<<
" input dimension: "
<<
inputDimension
(dataTrain) << endl;
62
63
cout <<
"Test set - number of data points: "
<< dataTest.
numberOfElements
()
64
<<
" number of classes: "
<<
numberOfClasses
(dataTest)
65
<<
" input dimension: "
<<
inputDimension
(dataTest) << endl;
66
67
68
//Train the model
69
CARTTrainer trainer;
70
CARTClassifier<RealVector> model;
71
trainer.train(model, dataTrain);
72
73
// evaluate Random Forest classifier
74
ZeroOneLoss<unsigned int, RealVector>
loss;
75
Data<RealVector>
prediction = model(dataTrain.
inputs
());
76
cout <<
"CART on training set accuracy: "
<< 1. - loss.
eval
(dataTrain.
labels
(), prediction) << endl;
77
78
prediction = model(dataTest.
inputs
());
79
cout <<
"CART on test set accuracy: "
<< 1. - loss.
eval
(dataTest.
labels
(), prediction) << endl;
80
81
}