CSvmGridSearchTutorial.cpp
#include <shark/Models/Kernels/GaussianRbfKernel.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>
#include <shark/Algorithms/Trainers/CSvmTrainer.h>
#include <shark/Data/DataDistribution.h>

#include <shark/ObjectiveFunctions/CrossValidationError.h>
#include <shark/Algorithms/DirectSearch/GridSearch.h>
#include <shark/Algorithms/JaakkolaHeuristic.h>

using namespace shark;
using namespace std;


int main()
{
    // problem definition
    Chessboard prob;   // toy binary classification problem (chessboard distribution)
    ClassificationDataset dataTrain = prob.generateDataset(200);
    ClassificationDataset dataTest = prob.generateDataset(10000);

    // SVM setup
    GaussianRbfKernel<> kernel(0.5, true); // second argument: treat the bandwidth as an unconstrained parameter, searched on a log scale below
    KernelClassifier<RealVector> svm;
    bool offset = true;
    bool unconstrained = true;
    CSvmTrainer<RealVector> trainer(&kernel, 1.0, offset, unconstrained);

    // cross-validation error
    const unsigned int K = 5; // number of folds
    ZeroOneLoss<unsigned int> loss;
    CVFolds<ClassificationDataset> folds = createCVSameSizeBalanced(dataTrain, K);
    CrossValidationError<KernelClassifier<RealVector>, unsigned int> cvError(
        folds, &trainer, &svm, &trainer, &loss
    );
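
    // Illustration (not part of the original example): cvError is itself an
    // objective function over the trainer's parameter vector, so a single
    // parameter setting can be scored directly. This sketch assumes the usual
    // eval(RealVector) interface that GridSearch also relies on; the candidate
    // values are arbitrary.
    RealVector candidate(2);
    candidate(0) = log(0.5); // kernel parameter (dimension 0 of the grid below)
    candidate(1) = 1.0;      // regularization parameter (dimension 1 of the grid below)
    cout << "CV error at the candidate point: " << cvError.eval(candidate) << endl;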


    // find best parameters

    // use Jaakkola's heuristic as a starting point for the grid-search
    JaakkolaHeuristic ja(dataTrain);
    double ljg = log(ja.gamma());
    cout << "Tommi Jaakkola says gamma = " << ja.gamma() << " and ln(gamma) = " << ljg << endl;

    GridSearch grid;
    vector<double> min(2);
    vector<double> max(2);
    vector<size_t> sections(2);
    // kernel parameter gamma: 9 grid points, +/- 4 around the Jaakkola estimate (log scale)
    min[0] = ljg - 4.0; max[0] = ljg + 4.0; sections[0] = 9;
    // regularization parameter C: 11 grid points
    min[1] = 0.0; max[1] = 10.0; sections[1] = 11;
    grid.configure(min, max, sections);
    grid.step(cvError);
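
    // Illustration (not part of the original example): a single step() of the
    // grid search evaluates cvError at every grid point (9 * 11 = 99 combinations
    // here) and stores the best one. Assuming the result set also carries the
    // objective value, the selected cross-validation error can be reported:
    cout << "best CV error found by the grid search: " << grid.solution().value << endl;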

    // train model on the full dataset
    trainer.setParameterVector(grid.solution().point);
    trainer.train(svm, dataTrain);
    cout << "grid.solution().point " << grid.solution().point << endl;
    cout << "C =\t" << trainer.C() << endl;
    cout << "gamma =\t" << kernel.gamma() << endl;

    // evaluate
    Data<unsigned int> output = svm(dataTrain.inputs());
    double train_error = loss.eval(dataTrain.labels(), output);
    cout << "training error:\t" << train_error << endl;
    output = svm(dataTest.inputs());
    double test_error = loss.eval(dataTest.labels(), output);
    cout << "test error: \t" << test_error << endl;
}