MNIST.h
Go to the documentation of this file.
1 /*!
2  * \brief Loads the MNIST benchmark problem.
3  *
4  * \author O. Krause, A.Fischer, K.Bruegge
5  * \date 2012
6  *
7  *
8  * \par Copyright 1995-2017 Shark Development Team
9  *
10  * <BR><HR>
11  * This file is part of Shark.
12  * <http://shark-ml.org/>
13  *
14  * Shark is free software: you can redistribute it and/or modify
15  * it under the terms of the GNU Lesser General Public License as published
16  * by the Free Software Foundation, either version 3 of the License, or
17  * (at your option) any later version.
18  *
19  * Shark is distributed in the hope that it will be useful,
20  * but WITHOUT ANY WARRANTY; without even the implied warranty of
21  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
22  * GNU Lesser General Public License for more details.
23  *
24  * You should have received a copy of the GNU Lesser General Public License
25  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
26  *
27  */
28 #ifndef UNSUPERVISED_RBM_PROBLEMS_MNIST_H
29 #define UNSUPERVISED_RBM_PROBLEMS_MNIST_H
30 
31 #include <shark/Data/Dataset.h>
32 #include <shark/LinAlg/Base.h>
33 #include <shark/Core/Random.h>
34 
35 #include <sstream>
36 #include <fstream>
37 #include <string>
38 namespace shark{
39 
40 /// \brief Reads in the famous MNIST data in possibly binarized form. The MNIST database itself is not included in Shark,
41 /// this class just helps loading it.
42 ///
43 ///MNIST is a set of handwritten digits.
44 ///It needs the filename of the file containing the database (can be downloaded from the web)
45 ///and the threshold for binarization. The threshold (between 0 and 255) describes when a gray value will be interpreted
46 ///as 1. Default is 127. If the threshold is 0, no binarization takes place.
47 class MNIST{
48 private:
// Path of the MNIST image file; init() opens this file.
50  std::string m_filename;
// Binarization threshold; init() keeps the raw gray values when it is 0.
// NOTE(review): declared as char although the class documentation allows 0..255 --
// where char is signed, values above 127 cannot be stored as given. Confirm intent.
51  char m_threshold;
// Batch size handed to createDataFromRange() in init() and returned by batchSize().
52  std::size_t m_batchSize;
53 
54  int readInt (unsigned char *memblock) const{
55  return ((int)memblock[0] << 24) + ((int)memblock[1] << 16) + ((int)memblock[2] << 8) + memblock[3];
56  }
57  void init(){
58  //m_name="MNIST";
59  std::ifstream infile(m_filename.c_str(), std::ios::binary);
60  SHARK_RUNTIME_CHECK(infile, "Can not open file!");
61 
62  //get file size
63  infile.seekg(0,std::ios::end);
64  std::ifstream::pos_type inputSize = infile.tellg();
65 
66 
67  unsigned char *memblock = new unsigned char [inputSize];
68  infile.seekg (0, std::ios::beg);
69  infile.read ((char *) memblock, inputSize);
70 
71  SHARK_RUNTIME_CHECK(readInt(memblock) == 2051, "magic number for mnist wrong!");
72  std::size_t numImages = readInt(memblock + 4);
73  std::size_t numRows = readInt(memblock + 8);
74  std::size_t numColumns = readInt(memblock + 12);
75  std::size_t sizeOfVis = numRows * numColumns;
76 
77  std::vector<RealVector> data(numImages,RealVector(sizeOfVis));
78  for (std::size_t i = 0; i != numImages; ++i){
79  RealVector imgVec(sizeOfVis);
80  if(m_threshold != 0){
81  for (size_t j = 0; j != sizeOfVis; ++j){
82  char pixel = memblock[ 16 + i * sizeOfVis + j ] > m_threshold;
83  data[i](j) = pixel;
84  }
85  }
86  else{
87  for (size_t j = 0; j != sizeOfVis; ++j){
88  data[i](j) = memblock[ 16 + i * sizeOfVis + j ];
89  }
90  }
91  }
92  delete [] memblock;
93  m_data = createDataFromRange(data,m_batchSize);
94  }
95 public:
96 
97  //Constructor. Sets the configurations from a property tree and imports the data set.
98  //@param filename the name of the file storing the dataset
99  //@param threshhold the threshold for turning gray values into ones
100  //@param batchSize the size of the batch
101  MNIST(std::string filename, char threshold = 127, std::size_t batchSize = 256)
102  : m_filename(filename), m_threshold(threshold), m_batchSize(batchSize){
103  init();
104  }
105 
106  //Returns the data vector
108  return m_data;
109  }
110 
111  //Returns the dimension of the pattern of MNIST.
112  std::size_t inputDimension() const {
113  return 28*28;
114  }
115 
116  //Returns the batch size.
117  std::size_t batchSize() const {
118  return m_batchSize;
119  }
120 
121 };
122 }
123 #endif
124