/*
 * Shark machine learning library
 * (documentation site navigation: About Shark, News!, Contribute,
 *  Credits and copyright, Downloads, Getting Started, Installation,
 *  Using the docs, Documentation, Tutorials, Quick references, Class list,
 *  Global functions, FAQ, Showroom)
 *
 * Location: examples / Benchmark / shark / nearest_neighbours.cpp
 * Go to the documentation of this file.
 */
#include <cstddef>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <string>

#include <shark/Data/SparseData.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>
#include <shark/Models/NearestNeighborClassifier.h>
#include <shark/Algorithms/NearestNeighbors/TreeNearestNeighbors.h>
#include <shark/Algorithms/NearestNeighbors/SimpleNearestNeighbors.h>
#include <shark/Models/Trees/KDTree.h>
#include <shark/Models/Kernels/LinearKernel.h>
#include <shark/Core/Timer.h>

using namespace shark;
using namespace std;

int
main
(
int
argc,
char
**argv) {
15
LabeledData<RealVector,unsigned int>
data;
16
importSparseData
(data,
"cod-rna"
,0,8192);
17
18
LabeledData<RealVector,unsigned int>
mnist;
19
importSparseData
(mnist,
"mnist"
,0,8192);
20
//~ {
21
//~ Timer time;
22
//~ KDTree<RealVector> kdtree(data.inputs());
23
//~ TreeNearestNeighbors<RealVector,unsigned int> algorithmKD(data,&kdtree);
24
//~ NearestNeighborClassifier<RealVector> model(&algorithmKD, 10);
25
//~ ZeroOneLoss<> loss;
26
//~ double error = loss(data.labels(),model(data.inputs()));
27
//~ double time_taken = time.stop();
28
29
//~ cout << "kdtree: "<< time_taken <<" "<< error<<std::endl;
30
//~ }
31
32
{
33
Timer
time;
34
LinearKernel<RealVector>
euclideanKernel;
35
SimpleNearestNeighbors<RealVector,unsigned int>
simpleAlgorithm(data,&euclideanKernel);
36
NearestNeighborClassifier<RealVector> model(&simpleAlgorithm, 10);
37
ZeroOneLoss<>
loss;
38
double
error = loss(data.
labels
(),model(data.
inputs
()));
39
double
time_taken = time.
stop
();
40
41
cout <<
"brute-force: "
<< time_taken <<
" "
<< error<<std::endl;
42
}
43
44
{
45
Timer
time;
46
LinearKernel<RealVector>
euclideanKernel;
47
SimpleNearestNeighbors<RealVector,unsigned int>
simpleAlgorithm(mnist,&euclideanKernel);
48
NearestNeighborClassifier<RealVector> model(&simpleAlgorithm, 10);
49
ZeroOneLoss<>
loss;
50
double
error = loss(mnist.
labels
(),model(mnist.
inputs
()));
51
double
time_taken = time.
stop
();
52
53
cout <<
"brute-force-mnist: "
<< time_taken <<
" "
<< error<<std::endl;
54
}
55
56
}