Shark machine learning library
About Shark
News!
Contribute
Credits and copyright
Downloads
Getting Started
Installation
Using the docs
Documentation
Tutorials
Quick references
Class list
Global functions
FAQ
Showroom
examples
Benchmark
shark
logistic_regression_SAG.cpp
Go to the documentation of this file.
1
#include <
shark/Data/SparseData.h
>
2
#include <
shark/ObjectiveFunctions/Loss/CrossEntropy.h
>
3
#include <
shark/Algorithms/Trainers/LinearSAGTrainer.h
>
4
5
#include <
shark/Core/Timer.h
>
6
#include <iostream>
7
using namespace
shark
;
8
using namespace
std
;
9
10
11
template
<
class
InputType>
12
void
run
(
LabeledData<InputType,unsigned int>
const
& data,
double
alpha,
unsigned
int
epochs){
13
CrossEntropy
loss;
14
LinearClassifier<InputType>
model;
15
16
17
LinearSAGTrainer<InputType,unsigned int>
trainer(&loss,alpha);
18
trainer.
setEpochs
(epochs);
19
20
Timer
time;
21
trainer.
train
(model, data);
22
double
time_taken = time.
stop
();
23
24
cout <<
"Cross-Entropy: "
<< loss(data.
labels
(),model.
decisionFunction
()(data.
inputs
()))<<std::endl;
25
cout <<
"Time:\n"
<< time_taken << endl;
26
}
27
int
main
(
int
argc,
char
**argv) {
28
ClassificationDataset
data_dense;
29
importSparseData
(data_dense,
"mnist"
,0,8192);
30
data_dense =
transformLabels
(data_dense, [](
unsigned
int
y){
return
y%2;});
31
LabeledData<CompressedRealVector,unsigned int>
data_sparse;
32
importSparseData
(data_sparse,
"rcv1_train.binary"
,0,8192);
33
34
double
alpha = 0.1;
35
run
(data_dense, alpha, 200);
36
run
(data_sparse, alpha, 2000);
37
38
}