CVFolds.cpp
Go to the documentation of this file.
1 //header needed for cross validation
3 
4 //headers needed for our test problem
12 
13 //we use an artifical learning problem
15 
16 using namespace shark;
17 using namespace std;
18 
19 ///In this example, you will learn to create and use partitions
20 ///for cross validation.
21 ///This tutorial describes a handmade solution which does not use the Crossvalidation error function
22 ///which is also provided by shark. We do this, because it gives a better impression of what Cross Validation does.
23 
24 ///The Test Problem receives the regularization parameter and a dataset
25 ///and returns the error. Skip to the main if you are not interested
26 ///in the problem itself. But here you can also see how to create
27 ///regularized error functions. so maybe it's still worth taking a look ;)
28 double trainProblem(const RegressionDataset& training, RegressionDataset const& validation, double regularization){
30  LinearModel<RealVector> layer2(20,1);
31  ConcatenatedModel<RealVector> network = layer1 >> layer2;
32  initRandomUniform(network,-1,1);
33 
34  //the error function is a combination of MSE and 2-norm error
35  SquaredLoss<> loss;
36  ErrorFunction error(training,&network,&loss);
37  TwoNormRegularizer regularizer;
38  error.setRegularizer(regularization, &regularizer);
39 
40  //now train for a number of iterations using Rprop
41  IRpropPlus optimizer;
42  error.init();
43  //initialize with our predefined point, since
44  //the combined function can't propose one.
45  optimizer.init(error);
46  for(unsigned iter = 0; iter != 5000; ++iter)
47  {
48  optimizer.step(error);
49  }
50 
51  //validate and return the error without regularization
52  return loss(network(validation.inputs()),validation.labels());
53 }
54 
55 
56 /// What is Cross Validation(CV)? In Cross Validation the dataset is partitioned in
57 /// several validation data sets. For a given validation dataset the remainder of the dataset
58 /// - every other validation set - forms the training part. During every evaluation of the error function,
59 /// the problem is solved using the training part and the final error is computed using the validation part.
60 /// The mean of all validation sets trained this way is the final error of the solution found.
61 /// This quite complex procedure is used to minimize the bias introduced by the dataset itself and makes
62 /// overfitting of the solution harder.
63 int main(){
64  //we first create the problem. in this simple tutorial,
65  //it's only the 1D wave function sin(x)/x + noise
66  Wave wave;
67  RegressionDataset dataset;
68  dataset = wave.generateDataset(100);
69 
70  //now we want to create the cv folds. For this, we
71  //use the CVDatasetTools.h. There are a few functions
72  //to create folds. in this case, we create 4
73  //partitions with the same size. so we have 75 train
74  //and 25 validation data points
76 
77  //now we want to use the folds to find the best regularization
78  //parameter for our problem. we use a grid search to accomplish this
79  double bestValidationError = 1e4;
80  double bestRegularization = 0;
81  for (double regularization = 1.e-5; regularization < 1.e-3; regularization *= 2) {
82  double result = 0;
83  for (std::size_t fold = 0; fold != folds.size(); ++fold){ //CV
84  // access the fold
85  RegressionDataset training = folds.training(fold);
86  RegressionDataset validation = folds.validation(fold);
87  // train
88  result += trainProblem(training, validation, regularization);
89  }
90  result /= folds.size();
91 
92  // check whether this regularization parameter leads to better results
93  if (result < bestValidationError)
94  {
95  bestValidationError = result;
96  bestRegularization = regularization;
97  }
98 
99  // print status:
100  std::cout << regularization << " " << result << std::endl;
101  }
102 
103  // print the best value found
104  cout << "RESULTS: " << std::endl;
105  cout << "======== " << std::endl;
106  cout << "best validation error: " << bestValidationError << std::endl;
107  cout << "best regularization: " << bestRegularization<< std::endl;
108 }