FFNNBasicTutorial.cpp
Go to the documentation of this file.
//standard library
#include <cstddef> //std::size_t
#include <iostream> //std::cout, std::cerr
//the model
#include <shark/Models/LinearModel.h>//single dense layer
#include <shark/Models/ConcatenatedModel.h>//for stacking layers, provides operator>>
//training the model
#include <shark/ObjectiveFunctions/ErrorFunction.h>//error function, allows for minibatch training
#include <shark/ObjectiveFunctions/Loss/CrossEntropy.h> // loss used for supervised training
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h> // loss used for evaluation of performance
#include <shark/Algorithms/GradientDescent/Adam.h> //optimizer: Adam, a gradient-based optimizer
#include <shark/Data/SparseData.h> //loading the dataset
using namespace shark;
11 
12 int main(int argc, char **argv)
13 {
14  if(argc < 2) {
15  std::cerr << "usage: " << argv[0] << " path/to/mnist_subset.libsvm" << std::endl;
16  return 1;
17  }
18  std::size_t hidden1 = 200;
19  std::size_t hidden2 = 100;
20  std::size_t iterations = 1000;
21 
22  std::size_t batchSize = 256;
24  importSparseData( data, argv[1], 0, batchSize );
25  data.shuffle(); //shuffle data randomly
26  auto test = splitAtElement(data, 70 * data.numberOfElements() / 100);//split a test set
27  std::size_t numClasses = numberOfClasses(data);
28  //We use a dense linear model with rectifier activations
30 
31  //build the network
32  DenseLayer layer1(data.inputShape(),hidden1);
33  DenseLayer layer2(layer1.outputShape(),hidden2);
34  LinearModel<RealVector> output(layer2.outputShape(),numClasses);
35  auto network = layer1 >> layer2 >> output;
36  //create the supervised problem.
37  CrossEntropy loss;
38  ErrorFunction error(data, &network, &loss, true);//enable minibatch training
39 
40  //optimize the model
41  std::cout<<"training network"<<std::endl;
42  initRandomNormal(network,0.001);
43  Adam optimizer;
44  error.init();
45  optimizer.init(error);
46  for(std::size_t i = 0; i != iterations; ++i){
47  optimizer.step(error);
48  std::cout<<i<<" "<<optimizer.solution().value<<std::endl;
49  }
50  network.setParameterVector(optimizer.solution().point);
51 
52  //evaluation
54  Data<RealVector> predictionTrain = network(data.inputs());
55  std::cout << "classification error,train: " << loss01.eval(data.labels(), predictionTrain) << std::endl;
56 
57  Data<RealVector> prediction = network(test.inputs());
58  std::cout << "classification error,test: " << loss01.eval(test.labels(), prediction) << std::endl;
59 
60 }
61