Shark machine learning library
About Shark
News!
Contribute
Credits and copyright
Downloads
Getting Started
Installation
Using the docs
Documentation
Tutorials
Quick references
Class list
Global functions
FAQ
Showroom
examples
Unsupervised
KMeansTutorial.cpp
Go to the documentation of this file.
//===========================================================================
/*!
 *
 *
 * \brief k-means clustering tutorial sample code; requires the data
 * set faithful.csv.
 *
 *
 *
 * \author C. Igel
 * \date 2011
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <http://shark-ml.org/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================
#include <shark/Data/Csv.h> //load the csv file
#include <shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h> //normalize

#include <shark/Algorithms/KMeans.h> //k-means algorithm
#include <shark/Models/Clustering/HardClusteringModel.h> //model performing hard clustering of points

#include <cstdlib>   // exit, EXIT_SUCCESS, EXIT_FAILURE
#include <exception> // std::exception
#include <fstream>   // ofstream for the result files
#include <iostream>

using namespace shark;
using namespace std;
int
main
(
int
argc,
char
**argv) {
48
if
(argc < 2) {
49
cerr <<
"usage: "
<< argv[0] <<
" (filename)"
<< endl;
50
exit(EXIT_FAILURE);
51
}
52
// read data
53
UnlabeledData<RealVector>
data;
54
try
{
55
importCSV
(data, argv[1],
' '
);
56
}
57
catch
(...) {
58
cerr <<
"unable to read data from file "
<< argv[1] << endl;
59
exit(EXIT_FAILURE);
60
}
61
std::size_t elements = data.
numberOfElements
();
62
63
// write statistics of input data
64
cout <<
"number of data points: "
<< elements <<
" dimensions: "
<<
dataDimension
(data) << endl;
65
66
// normalize data
67
Normalizer<>
normalizer;
68
NormalizeComponentsUnitVariance<>
normalizingTrainer(
true
);
//zero mean
69
normalizingTrainer.
train
(normalizer, data);
70
data = normalizer(data);
71
72
// compute centroids using k-means clustering
73
Centroids
centroids;
74
size_t
iterations =
kMeans
(data, 2, centroids);
75
// report number of iterations by the clustering algorithm
76
cout <<
"iterations: "
<< iterations << endl;
77
78
// write cluster centers/centroids
79
Data<RealVector>
const
& c = centroids.
centroids
();
80
cout<<c<<std::endl;
81
82
// cluster data
83
HardClusteringModel<RealVector>
model(¢roids);
84
Data<unsigned>
clusters = model(data);
85
86
// write results to files
87
ofstream c1(
"cl1.csv"
);
88
ofstream c2(
"cl2.csv"
);
89
ofstream cc(
"clc.csv"
);
90
for
(std::size_t i=0; i != elements; i++) {
91
if
(clusters.
element
(i))
92
c1 << data.
element
(i)(0) <<
" "
<< data.
element
(i)(1) << endl;
93
else
94
c2 << data.
element
(i)(0) <<
" "
<< data.
element
(i)(1) << endl;
95
}
96
cc << c.
element
(0)(0) <<
" "
<< c.
element
(0)(1) << endl;
97
cc << c.
element
(1)(0) <<
" "
<< c.
element
(1)(1) << endl;
98
}