logistic_regression_SAG.cpp
Go to the documentation of this file.
4 
#include <shark/Data/SparseData.h>
#include <shark/ObjectiveFunctions/Loss/CrossEntropy.h>
#include <shark/Models/LinearClassifier.h>
#include <shark/Algorithms/Trainers/LinearSAGTrainer.h>
#include <shark/Core/Timer.h>
#include <exception>
#include <iostream>
#include <string>
using namespace shark;
using namespace std;
9 
10 
11 template<class InputType>
12 void run(LabeledData<InputType,unsigned int> const& data, double alpha, unsigned int epochs){
13  CrossEntropy loss;
15 
16 
17  LinearSAGTrainer<InputType,unsigned int> trainer(&loss,alpha);
18  trainer.setEpochs(epochs);
19 
20  Timer time;
21  trainer.train(model, data);
22  double time_taken = time.stop();
23 
24  cout << "Cross-Entropy: " << loss(data.labels(),model.decisionFunction()(data.inputs()))<<std::endl;
25  cout << "Time:\n" << time_taken << endl;
26 }
27 int main(int argc, char **argv) {
28  ClassificationDataset data_dense;
29  importSparseData(data_dense, "mnist",0,8192);
30  data_dense = transformLabels(data_dense, [](unsigned int y){ return y%2;});
32  importSparseData(data_sparse, "rcv1_train.binary",0,8192);
33 
34  double alpha = 0.1;
35  run(data_dense, alpha, 200);
36  run(data_sparse, alpha, 2000);
37 
38 }