// McSvm.cpp
// Shark example program; see the documentation page for this file for details.
#include <cstdio>
#include <iostream>
#include <string>
#include <tuple>
#include <vector>

#include <shark/LinAlg/Base.h>
#include <shark/Core/Random.h>
#include <shark/Data/Dataset.h>
#include <shark/Data/DataDistribution.h>
#include <shark/Models/Kernels/GaussianRbfKernel.h>
#include <shark/Models/Kernels/KernelExpansion.h>
#include <shark/Algorithms/Trainers/CSvmTrainer.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>
11 
12 
13 using namespace shark;
14 
15 
16 // data generating distribution for our toy
17 // multi-category classification problem
18 /// @cond EXAMPLE_SYMBOLS
19 class Problem : public LabeledDataDistribution<RealVector, unsigned int>
20 {
21 public:
22  void draw(RealVector& input, unsigned int& label)const
23  {
24  label = random::discrete(random::globalRng, 0, 4);
25  input.resize(1);
26  input(0) = random::gauss(random::globalRng) + 3.0 * label;
27  }
28 };
29 /// @endcond
30 
31 int main()
32 {
33  // experiment settings
34  unsigned int ell = 30;
35  unsigned int tests = 100;
36  double C = 10.0;
37  double gamma = 0.5;
38 
39  // generate a very simple dataset with a little noise
40  Problem problem;
41  ClassificationDataset training = problem.generateDataset(ell);
42  ClassificationDataset test = problem.generateDataset(tests);
43 
44  // kernel function
45  GaussianRbfKernel<> kernel(gamma);
46 
47  // SVM kernel classifiers
49 
50  // loss measuring classification errors
52 
53  // There are 9 trainers for multi-class SVMs in Shark which can train with or without bias:
54  std::tuple<std::string,McSvm,bool> machines[18] ={
55  std::make_tuple("OVA", McSvm::OVA,false),
56  std::make_tuple("CS", McSvm::CS,false),
57  std::make_tuple("WW",McSvm::WW,false),
58  std::make_tuple("LLW",McSvm::LLW,false),
59  std::make_tuple("ADM",McSvm::ADM,false),
60  std::make_tuple("ATS",McSvm::ATS,false),
61  std::make_tuple("ATM",McSvm::ATM,false),
62  std::make_tuple("MMR",McSvm::MMR,false),
63  std::make_tuple("ReinforcedSvm",McSvm::ReinforcedSvm,false),
64  std::make_tuple("OVA", McSvm::OVA,true),
65  std::make_tuple("CS", McSvm::CS,true),
66  std::make_tuple("WW",McSvm::WW,true),
67  std::make_tuple("LLW",McSvm::LLW,true),
68  std::make_tuple("ADM",McSvm::ADM,true),
69  std::make_tuple("ATS",McSvm::ATS,true),
70  std::make_tuple("ATM",McSvm::ATM,true),
71  std::make_tuple("MMR",McSvm::MMR,true),
72  std::make_tuple("ReinforcedSvm",McSvm::ReinforcedSvm,true)
73  };
74 
75  std::printf("SHARK multi-class SVM example - training 18 machines:\n");
76  for (int i=0; i<18; i++)
77  {
78  CSvmTrainer<RealVector> trainer(&kernel, C, std::get<2>(machines[i]));
79  trainer.setMcSvmType(std::get<1>(machines[i]));
80  trainer.train(svm, training);
81  Data<unsigned int> output = svm(training.inputs());
82  double train_error = loss.eval(training.labels(), output);
83  output = svm(test.inputs());
84  double test_error = loss.eval(test.labels(), output);
85 
86  std::cout<<std::get<0>(machines[i])<<(trainer.trainOffset()? " w bias ":" w/o bias");
87  std::cout<<"\ttraining error="<<train_error;
88  std::cout<<"\ttest error="<<test_error;
89  std::cout<<"\titerations="<<trainer.solutionProperties().iterations;
90  std::cout<<"\ttime="<<trainer.solutionProperties().seconds<<std::endl;
91 
92 
93  }
94 }