// MultiTaskSvm.cpp
// Example program: trains a multi-task kernel SVM on two tasks and measures
// how well it transfers to a third task.
#include <cmath>
#include <cstddef>
#include <iostream>

#include <shark/Algorithms/Trainers/CSvmTrainer.h>
#include <shark/Data/DataDistribution.h>
#include <shark/Models/Kernels/GaussianRbfKernel.h>
#include <shark/Models/Kernels/MultiTaskKernel.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>

using namespace shark;
using namespace std;


12 // RealVector input with task index
14 
15 
16 // Multi-task problem with up to three tasks.
17 class MultiTaskProblem : public LabeledDataDistribution<InputType, unsigned int>
18 {
19 public:
20  MultiTaskProblem()
21  {
22  m_task[0] = true;
23  m_task[1] = true;
24  m_task[2] = true;
25  }
26 
27  void setTasks(bool task0, bool task1, bool task2)
28  {
29  m_task[0] = task0;
30  m_task[1] = task1;
31  m_task[2] = task2;
32  }
33 
34  void draw(InputType& input, unsigned int& label) const
35  {
36  size_t taskindex = 0;
37  do {
38  taskindex = random::uni(random::globalRng, 0, 2);
39  } while (! m_task[taskindex]);
40  double x1 = random::gauss(random::globalRng);
41  double x2 = 3.0 * random::gauss(random::globalRng);
42  unsigned int y = (x1 > 0.0) ? 1 : 0;
43  double alpha = 0.05 * M_PI * taskindex;
44  input.input.resize(2);
45  input.input(0) = cos(alpha) * x1 - sin(alpha) * x2;
46  input.input(1) = sin(alpha) * x1 + cos(alpha) * x2;
47  input.task = taskindex;
48  label = y;
49  }
50 
51 protected:
52  bool m_task[3];
53 };
54 
55 
56 int main(int argc, char** argv)
57 {
58  // experiment settings
59  unsigned int ell_train = 1000; // number of training data point from tasks 0 and 1
60  unsigned int ell_test = 1000; // number of test data points from task 2
61  double C = 1.0; // regularization parameter
62  double gamma = 0.5; // kernel bandwidth parameter
63 
64  // generate data
65  MultiTaskProblem problem;
66  problem.setTasks(true, true, false);
67  LabeledData<InputType, unsigned int> training = problem.generateDataset(ell_train);
68  problem.setTasks(false, false, true);
69  LabeledData<InputType, unsigned int> test = problem.generateDataset(ell_test);
70 
71  // merge all inputs into a single data object
72  Data<InputType> data(ell_train + ell_test);
73  for (size_t i=0; i<ell_train; i++)
74  data.element(i) = training.inputs().element(i);
75  for (size_t i=0; i<ell_test; i++)
76  data.element(ell_train + i) = test.inputs().element(i);
77 
78  // create kernel objects
79  GaussianRbfKernel<RealVector> inputKernel(gamma); // Gaussian kernel on inputs
80  GaussianTaskKernel<RealVector> taskKernel( // task similarity kernel
81  data, // all inputs with task indices, no labels
82  3, // total number of tasks
83  inputKernel, // base kernel for input similarity
84  gamma); // bandwidth for task similarity kernel
85  MultiTaskKernel<RealVector> multiTaskKernel(&inputKernel, &taskKernel);
86 
87  // train the SVM
89  CSvmTrainer<InputType> trainer(&multiTaskKernel, C,false);
90  cout << "training ..." << endl;
91  trainer.train(ke, training);
92  cout << "done." << endl;
93 
95  Data<RealVector> output;
96 
97  // evaluate training performance
98  double trainError = loss.eval(training.labels(), ke(training.inputs()));
99  cout << "training error:\t" << trainError << endl;
100 
101  // evaluate its transfer performance
102  double testError = loss.eval(test.labels(), ke(test.inputs()));
103  cout << "test error:\t" << testError << endl;
104 }