// CSvmGridSearchTutorial.cpp
// Shark tutorial: grid-search model selection for a C-SVM on the Chessboard problem.
// NOTE(review): the original include block was stripped by the documentation
// extractor (only bare line numbers remained); reconstructed from the Shark tutorial.
#include <shark/Algorithms/DirectSearch/GridSearch.h>
#include <shark/Algorithms/JaakkolaHeuristic.h>
#include <shark/Algorithms/Trainers/CSvmTrainer.h>
#include <shark/Data/DataDistribution.h>
#include <shark/Models/Kernels/GaussianRbfKernel.h>
#include <shark/ObjectiveFunctions/CrossValidationError.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>

using namespace shark;
using namespace std;
14 int main()
15 {
16  // problem definition
17  Chessboard prob;
18  ClassificationDataset dataTrain = prob.generateDataset(200);
19  ClassificationDataset dataTest= prob.generateDataset(10000);
20 
21  // SVM setup
22  GaussianRbfKernel<> kernel(0.5, true); //unconstrained?
24  bool offset = true;
25  bool unconstrained = true;
26  CSvmTrainer<RealVector> trainer(&kernel, 1.0, offset,unconstrained);
27 
28  // cross-validation error
29  const unsigned int K = 5; // number of folds
33  folds, &trainer, &svm, &trainer, &loss
34  );
35 
36 
37  // find best parameters
38 
39  // use Jaakkola's heuristic as a starting point for the grid-search
40  JaakkolaHeuristic ja(dataTrain);
41  double ljg = log(ja.gamma());
42  cout << "Tommi Jaakkola says gamma = " << ja.gamma() << " and ln(gamma) = " << ljg << endl;
43 
44  GridSearch grid;
45  vector<double> min(2);
46  vector<double> max(2);
47  vector<size_t> sections(2);
48  // kernel parameter gamma
49  min[0] = ljg-4.; max[0] = ljg+4; sections[0] = 9;
50  // regularization parameter C
51  min[1] = 0.0; max[1] = 10.0; sections[1] = 11;
52  grid.configure(min, max, sections);
53  grid.step(cvError);
54 
55  // train model on the full dataset
56  trainer.setParameterVector(grid.solution().point);
57  trainer.train(svm, dataTrain);
58  cout << "grid.solution().point " << grid.solution().point << endl;
59  cout << "C =\t" << trainer.C() << endl;
60  cout << "gamma =\t" << kernel.gamma() << endl;
61 
62  // evaluate
63  Data<unsigned int> output = svm(dataTrain.inputs());
64  double train_error = loss.eval(dataTrain.labels(), output);
65  cout << "training error:\t" << train_error << endl;
66  output = svm(dataTest.inputs());
67  double test_error = loss.eval(dataTest.labels(), output);
68  cout << "test error: \t" << test_error << endl;
69 }