Shark machine learning library
About Shark
News!
Contribute
Credits and copyright
Downloads
Getting Started
Installation
Using the docs
Documentation
Tutorials
Quick references
Class list
Global functions
FAQ
Showroom
examples
Benchmark
shark
logistic_regression_LBFGS.cpp
Go to the documentation of this file.
1
#include <
shark/Data/SparseData.h
>
2
#include <
shark/ObjectiveFunctions/Loss/CrossEntropy.h
>
3
#include <
shark/ObjectiveFunctions/Regularizer.h
>
4
5
#include <
shark/Algorithms/GradientDescent/LBFGS.h
>
6
#include <
shark/ObjectiveFunctions/ErrorFunction.h
>
7
#include <
shark/Models/LinearModel.h
>
8
9
#include <
shark/Core/Timer.h
>
10
#include <iostream>
11
using namespace
shark
;
12
using namespace
std
;
13
14
int
main
(
int
argc,
char
**argv) {
15
ClassificationDataset
data;
16
importSparseData
(data,
"mnist"
,0,8192);
17
double
alpha = 0.1;
18
CrossEntropy
loss;
19
LinearClassifier<>
model;
20
21
//Setting up the problem
22
model.
decisionFunction
().setStructure(
inputDimension
(data),
numberOfClasses
(data),
true
);
23
TwoNormRegularizer
regularizer;
24
ErrorFunction
error(data,&model.
decisionFunction
(),&loss);
25
error.
setRegularizer
(alpha,®ularizer);
26
27
//solving
28
Timer
time;
29
LBFGS
optimizer;
30
optimizer.
init
(error);
31
while
(error.evaluationCounter()<200){
32
optimizer.
step
(error);
33
}
34
model.
setParameterVector
(optimizer.
solution
().
point
);
35
double
time_taken = time.
stop
();
36
37
cout <<
"Cross-Entropy: "
<< loss(data.
labels
(),model.
decisionFunction
()(data.
inputs
()))<<std::endl;
38
cout <<
"Time:\n"
<< time_taken << endl;
39
}