CSvmGridSearchTutorial.cpp
Go to the documentation of this file.
7 #include <shark/Data/Dataset.h>
9 
10 
11 using namespace shark;
12 using namespace std;
13 
14 
// Tutorial entry point: selects the kernel parameter gamma and the
// regularization parameter C of a Gaussian-RBF C-SVM by grid search on a
// cross-validation error, retrains on the full training set with the selected
// parameters, and prints the resulting training and test error.
15 int main()
16 {
17  // problem definition: sample 200 training and 10000 test points from the Chessboard problem
18  Chessboard prob;
19  ClassificationDataset dataTrain = prob.generateDataset(200);
20  ClassificationDataset dataTest= prob.generateDataset(10000);
21 
22  // SVM setup
23  GaussianRbfKernel<> kernel(0.5, true); // first argument presumably the initial gamma; 'true' presumably selects the unconstrained parameter encoding (TODO confirm)
24  KernelExpansion<RealVector> svm(true); // 'true' presumably enables the offset (bias) term (TODO confirm)
25  CSvmTrainer<RealVector> trainer(&kernel, 1.0, true); // trainer uses the kernel above; 1.0 is presumably the initial C, 'true' presumably the unconstrained encoding (TODO confirm)
26 
27  // cross-validation error
28  const unsigned int N= 5; // number of folds
29 
// NOTE(review): listing lines 29-32 are missing from this view. The declarations
// of `folds`, `loss` and `cvError` (all used below) are expected here; the two
// lines that follow are the tail of the `cvError` constructor call.
33  folds, &trainer, &svm, &trainer, &loss
34  );
35 
36 
37  // find best parameters
38 
39  // use Jaakkola's heuristic as a starting point for the grid-search
40  JaakkolaHeuristic ja(dataTrain);
41  double ljg = log(ja.gamma()); // natural log of the heuristic gamma; centre of the gamma search range
42  cout << "Tommi Jaakkola says gamma = " << ja.gamma() << " and ln(gamma) = " << ljg << endl;
43 
44 
// Two-dimensional grid: index 0 is the kernel parameter, index 1 the
// regularization parameter; min/max give the range and sections the grid
// resolution per dimension.
45  GridSearch grid;
46  vector<double> min(2);
47  vector<double> max(2);
48  vector<size_t> sections(2);
49  min[0] = ljg-4.; max[0] = ljg+4; sections[0] = 17; // kernel parameter gamma, searched in log-space: ln(gamma) within +/-4 of the heuristic value
50  min[1] = 0.0; max[1] = 10.0; sections[1] = 11; // regularization parameter C; NOTE(review): given the 'unconstrained' trainer flag this range is presumably ln(C), not C itself - confirm
51  grid.configure(min, max, sections);
52  grid.step(cvError); // single search step over the configured grid using the cross-validation error
53 
54  // train model on the full dataset, using the best parameter vector found by the grid search
55  trainer.setParameterVector(grid.solution().point);
56  trainer.train(svm, dataTrain);
57  cout << "C =\t" << trainer.C() << endl;
58  cout << "gamma =\t" << kernel.gamma() << endl;
59 
60  // evaluate: apply the trained model to the training and test inputs and score the outputs with `loss`
61  Data<RealVector> output = svm(dataTrain.inputs());
62  double train_error = loss.eval(dataTrain.labels(), output);
63  cout << "training error:\t" << train_error << endl;
64  output = svm(dataTest.inputs());
65  double test_error = loss.eval(dataTest.labels(), output);
66  cout << "test error: \t" << test_error << endl;
67 }