linearRegressionTutorial.cpp
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Linear Regression Tutorial Sample Code
6  *
7  * This file is part of the "Linear Regression" tutorial.
8  * It requires some toy sample data that comes with the library.
9  *
10  *
11  *
12  * \author C. Igel
13  * \date 2011
14  *
15  *
16  * \par Copyright 1995-2017 Shark Development Team
17  *
18  * <BR><HR>
19  * This file is part of Shark.
20  * <http://shark-ml.org/>
21  *
22  * Shark is free software: you can redistribute it and/or modify
23  * it under the terms of the GNU Lesser General Public License as published
24  * by the Free Software Foundation, either version 3 of the License, or
25  * (at your option) any later version.
26  *
27  * Shark is distributed in the hope that it will be useful,
28  * but WITHOUT ANY WARRANTY; without even the implied warranty of
29  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
30  * GNU Lesser General Public License for more details.
31  *
32  * You should have received a copy of the GNU Lesser General Public License
33  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
34  *
35  */
36 //===========================================================================
37 
#include <shark/Data/Csv.h>

#include <cstdlib>
#include <exception>
#include <iostream>
43 
44 using namespace shark;
45 using namespace std;
46 
47 int main(int argc, char **argv) {
48  if(argc < 3) {
49  cerr << "usage: " << argv[0] << " (file with inputs/independent variables) (file with outputs/dependent variables)" << endl;
50  exit(EXIT_FAILURE);
51  }
52  Data<RealVector> inputs;
53  Data<RealVector> labels;
54  try {
55  importCSV(inputs, argv[1], ' ');
56  }
57  catch (...) {
58  cerr << "unable to read input data from file " << argv[1] << endl;
59  exit(EXIT_FAILURE);
60  }
61 
62  try {
63  importCSV(labels, argv[2]);
64  }
65  catch (...) {
66  cerr << "unable to read labels from file " << argv[2] << endl;
67  exit(EXIT_FAILURE);
68  }
69 
70  RegressionDataset data(inputs, labels);
71 
72 
73 
74  // trainer and model
75  LinearRegression trainer;
76  LinearModel<> model;
77 
78  // train model
79  trainer.train(model, data);
80 
81  // show model parameters
82  cout << "intercept: " << model.offset() << endl;
83  cout << "matrix: " << model.matrix() << endl;
84 
85  SquaredLoss<> loss;
86  Data<RealVector> prediction = model(data.inputs());
87  cout << "squared loss: " << loss(data.labels(), prediction) << endl;
88 }