// VersatileClassificationTutorial-NN.cpp
// See the Shark documentation page for this file for an explanation of the example.
1 
#include <exception>
#include <iostream>

#include <shark/Data/Dataset.h>
#include <shark/Data/Csv.h>

#include <shark/Models/NearestNeighborClassifier.h>
#include <shark/Algorithms/NearestNeighbors/TreeNearestNeighbors.h>
#include <shark/Models/Trees/KDTree.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>

using namespace shark;
12 
13 int main()
14 {
15  // Load data, use 70% for training and 30% for testing.
16  // The path is hard coded; make sure to invoke the executable
17  // from a place where the data file can be found. It is located
18  // under [shark]/examples/Supervised/data.
19  ClassificationDataset traindata, testdata;
20  importCSV(traindata, "data/quickstartData.csv", LAST_COLUMN, ' ');
21  testdata = splitAtElement(traindata, 70 * traindata.numberOfElements() / 100);
22 
23  unsigned int k = 3; // number of neighbors
24  KDTree<RealVector> tree(traindata.inputs());
25  TreeNearestNeighbors<RealVector, unsigned int> algorithm(traindata, &tree);
26  NearestNeighborClassifier<RealVector> model(&algorithm, k);
27 
28  Data<unsigned int> prediction = model(testdata.inputs());
29 
31  double error_rate = loss(testdata.labels(), prediction);
32 
33  std::cout << "model: " << model.name() << std::endl
34  << "test error rate: " << error_rate << std::endl;
35 }