Shark machine learning library
About Shark
News!
Contribute
Credits and copyright
Downloads
Getting Started
Installation
Using the docs
Documentation
Tutorials
Class list
Global functions
FAQ
Showroom
Main Page
Related Pages
Modules
Classes
examples
Supervised
CSvmGridSearchTutorial.cpp
Go to the documentation of this file.
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

#include <shark/Models/Kernels/GaussianRbfKernel.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>
#include <shark/ObjectiveFunctions/CrossValidationError.h>
#include <shark/Algorithms/Trainers/CSvmTrainer.h>
#include <shark/Algorithms/DirectSearch/GridSearch.h>
#include <shark/Algorithms/JaakkolaHeuristic.h>
#include <shark/Data/Dataset.h>
#include <shark/Data/DataDistribution.h>

using namespace shark;
using namespace std;
int
main
()
16
{
17
// problem definition
18
Chessboard
prob;
19
ClassificationDataset
dataTrain = prob.
generateDataset
(200);
20
ClassificationDataset
dataTest= prob.
generateDataset
(10000);
21
22
// SVM setup
23
GaussianRbfKernel<>
kernel(0.5,
true
);
//unconstrained?
24
KernelExpansion<RealVector>
svm(
true
);
//use offset?
25
CSvmTrainer<RealVector>
trainer(&kernel, 1.0,
true
);
//unconstrained?
26
27
// cross-validation error
28
const
unsigned
int
N= 5;
// number of folds
29
30
ZeroOneLoss<unsigned int, RealVector>
loss;
31
CVFolds<ClassificationDataset>
folds =
createCVSameSizeBalanced
(dataTrain, N);
32
CrossValidationError<KernelExpansion<RealVector>
,
unsigned
int
> cvError(
33
folds, &trainer, &svm, &trainer, &loss
34
);
35
36
37
// find best parameters
38
39
// use Jaakkola's heuristic as a starting point for the grid-search
40
JaakkolaHeuristic
ja(dataTrain);
41
double
ljg = log(ja.
gamma
());
42
cout <<
"Tommi Jaakkola says gamma = "
<< ja.
gamma
() <<
" and ln(gamma) = "
<< ljg << endl;
43
44
45
GridSearch
grid;
46
vector<double>
min
(2);
47
vector<double>
max
(2);
48
vector<size_t> sections(2);
49
min[0] = ljg-4.; max[0] = ljg+4; sections[0] = 17;
// kernel parameter gamma
50
min[1] = 0.0; max[1] = 10.0; sections[1] = 11;
// regularization parameter C
51
grid.
configure
(min, max, sections);
52
grid.
step
(cvError);
53
54
// train model on the full dataset
55
trainer.
setParameterVector
(grid.
solution
().point);
56
trainer.
train
(svm, dataTrain);
57
cout <<
"C =\t"
<< trainer.
C
() << endl;
58
cout <<
"gamma =\t"
<< kernel.
gamma
() << endl;
59
60
// evaluate
61
Data<RealVector>
output = svm(dataTrain.
inputs
());
62
double
train_error = loss.
eval
(dataTrain.
labels
(), output);
63
cout <<
"training error:\t"
<< train_error << endl;
64
output = svm(dataTest.
inputs
());
65
double
test_error = loss.
eval
(dataTest.
labels
(), output);
66
cout <<
"test error: \t"
<< test_error << endl;
67
}