KNNCrossValidationTutorial.cpp
//===========================================================================
/*!
 *
 * \brief       Nearest Neighbor Tutorial Sample Code
 *
 * \author      C. Igel
 * \date        2011
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <http://shark-ml.org/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================

#include <shark/Data/Csv.h>
#include <shark/Models/NearestNeighborModel.h>
#include <shark/Algorithms/NearestNeighbors/SimpleNearestNeighbors.h>
#include <shark/Models/Kernels/LinearKernel.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>
#include <shark/Data/CVDatasetTools.h>
#include <iostream>

using namespace shark;
using namespace std;

int main(int argc, char **argv) {
    if (argc < 2) {
        cerr << "usage: " << argv[0] << " (filename)" << endl;
        exit(EXIT_FAILURE);
    }
    // read data
    ClassificationDataset data;
    try {
        importCSV(data, argv[1], LAST_COLUMN, ' ');
    }
    catch (...) {
        cerr << "unable to read data from file " << argv[1] << endl;
        exit(EXIT_FAILURE);
    }

    cout << "number of data points: " << data.numberOfElements()
         << " number of classes: " << numberOfClasses(data)
         << " input dimension: " << inputDimension(data) << endl;

    // split data into training and test set
    ClassificationDataset dataTest = splitAtElement(data, .5 * data.numberOfElements());
    cout << "training data points: " << data.numberOfElements() << endl;
    cout << "test data points: " << dataTest.numberOfElements() << endl;

    // create 10 CV folds
    const unsigned int NFolds = 10;
    CVFolds<ClassificationDataset> folds = createCVSameSizeBalanced(data, NFolds);

    // we have 5 different values of k to test
    unsigned int k[] = {1, 3, 5, 7, 9};
    unsigned int numParameters = 5;

    ZeroOneLoss<unsigned int> loss; // loss for evaluation
    LinearKernel<> metric;          // linear distance measure

    // find the best number of neighbors using CV
    unsigned int best_k = 0;
    double best_error = 2; // maximum 0-1 loss is 1
    // for every parameter...
    for (std::size_t p = 0; p != numParameters; ++p){
        double error = 0;
        // calculate the CV error
        for (std::size_t i = 0; i != NFolds; ++i){
            SimpleNearestNeighbors<RealVector, unsigned int> algorithm(folds.training(i), &metric);
            NearestNeighborModel<RealVector, unsigned int> KNN(&algorithm, k[p]);
            error += loss(folds.validation(i).labels(), KNN(folds.validation(i).inputs()));
        }
        error /= NFolds;
        // print the CV error for the current parameter
        std::cout << k[p] << " " << error << std::endl;
        // if the error is better, we keep it
        if (error < best_error){
            best_k = k[p];
            best_error = error;
        }
    }
    // evaluate the best parameter found on the test set, using the full training set
    SimpleNearestNeighbors<RealVector, unsigned int> algorithm(data, &metric);
    NearestNeighborModel<RealVector, unsigned int> KNN(&algorithm, best_k);
    std::cout << "NearestNeighbors: " << loss(dataTest.labels(), KNN(dataTest.inputs())) << '\n';
    std::cout << "K: " << best_k << std::endl;
}
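
The tutorial above reads its data from a CSV file (labels in the last column, space-separated, as passed to importCSV). If no such file is at hand, the same k-NN components can be exercised on a dataset built in memory. The following is a minimal sketch, not part of the tutorial itself: it assumes createLabeledDataFromRange from shark/Data/Dataset.h, and the four toy points and their labels are made up for illustration.

#include <shark/Data/Dataset.h>
#include <shark/Models/NearestNeighborModel.h>
#include <shark/Algorithms/NearestNeighbors/SimpleNearestNeighbors.h>
#include <shark/Models/Kernels/LinearKernel.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>
#include <vector>
#include <iostream>

using namespace shark;

int main() {
    // four toy points in two dimensions with binary labels (illustrative data only)
    std::vector<RealVector> inputs(4, RealVector(2));
    inputs[0](0) = 0.0; inputs[0](1) = 0.0;
    inputs[1](0) = 0.1; inputs[1](1) = 0.1;
    inputs[2](0) = 1.0; inputs[2](1) = 1.0;
    inputs[3](0) = 0.9; inputs[3](1) = 1.1;
    std::vector<unsigned int> labels = {0, 0, 1, 1};
    ClassificationDataset data = createLabeledDataFromRange(inputs, labels);

    // 1-nearest-neighbor classifier with the linear kernel as distance measure,
    // using the same classes as the tutorial above
    LinearKernel<> metric;
    SimpleNearestNeighbors<RealVector, unsigned int> algorithm(data, &metric);
    NearestNeighborModel<RealVector, unsigned int> knn(&algorithm, 1);

    // 0-1 loss of the model on the same toy data it was built from
    ZeroOneLoss<unsigned int> loss;
    std::cout << "training 0-1 loss: " << loss(data.labels(), knn(data.inputs())) << std::endl;
}

Once such a dataset exists, the cross-validation part of the tutorial applies unchanged: createCVSameSizeBalanced only needs the ClassificationDataset, regardless of whether it came from a CSV file or from memory.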