NearestNeighborClassifier.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Nearest neighbor classification
6  *
7  *
8  *
9  * \author T. Glasmachers, O. Krause
10  * \date 2012
11  *
12  *
13  * \par Copyright 1995-2017 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://shark-ml.org/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_MODELS_NEARESTNEIGHBORCLASSIFIER_H
36 #define SHARK_MODELS_NEARESTNEIGHBORCLASSIFIER_H
37 
40 #include <algorithm>
41 namespace shark {
42 
43 
44 ///
45 /// \brief Nearest Neighbor Classifier.
46 ///
47 /// \par
48 /// The NearestNeighborClassifier predicts a class label
49 /// according to a local majority decision among its k
50 /// nearest neighbors. It is not specified how ties are
51 /// broken.
52 ///
53 /// This model requires the use of one of Shark's nearest neighbor algorithms.
54 /// \see AbstractNearestNeighbors
55 template <class InputType>
56 class NearestNeighborClassifier : public AbstractModel<InputType, unsigned int>
57 {
58 public:
63 
64  /// \brief Type of distance-based weights.
66  {
67  UNIFORM, ///< uniform (= no) distance-based weights
68  ONE_OVER_DISTANCE, ///< weight each neighbor's label with 1/distance
69  };
70 
71  ///\brief Constructor
72  ///
73  /// \param algorithm the algorithm used for nearest neighbor search
74  /// \param neighbors: number of neighbors
75  NearestNeighborClassifier(NearestNeighbors const* algorithm, std::size_t neighbors = 3)
76  : m_algorithm(algorithm)
77  , m_classes(numberOfClasses(algorithm->dataset()))
80  { }
81 
82  /// \brief From INameable: return the class name.
83  std::string name() const
84  { return "NearestNeighborClassifier"; }
85 
86 
87  /// return the number of neighbors
88  std::size_t neighbors() const{
89  return m_neighbors;
90  }
91 
92  /// set the number of neighbors
93  void setNeighbors(std::size_t neighbors){
95  }
96 
97  /// query the way distances enter as weights
99  { return m_distanceWeights; }
100 
101  /// set the way distances enter as weights
103  { m_distanceWeights = dw; }
104 
105  /// get internal parameters of the model
106  virtual RealVector parameterVector() const{
107  RealVector parameters(1);
108  parameters(0) = (double)m_neighbors;
109  return parameters;
110  }
111 
112  /// set internal parameters of the model
113  virtual void setParameterVector(RealVector const& newParameters){
114  SHARK_RUNTIME_CHECK(newParameters.size() == 1, "Invalid number of parameters");
115  m_neighbors = (std::size_t)newParameters(0);
116  }
117 
118  /// return the size of the parameter vector
119  virtual std::size_t numberOfParameters() const{
120  return 1;
121  }
122 
123  boost::shared_ptr<State> createState()const{
124  return boost::shared_ptr<State>(new EmptyState());
125  }
126 
127  using base_type::eval;
128 
129  /// soft k-nearest-neighbor prediction
130  void eval(BatchInputType const& patterns, BatchOutputType& output, State& state)const{
131  std::size_t numPatterns = batchSize(patterns);
132  std::vector<typename NearestNeighbors::DistancePair> neighbors = m_algorithm->getNeighbors(patterns,m_neighbors);
133 
134  output.resize(numPatterns);
135  output.clear();
136 
137  for(std::size_t p = 0; p != numPatterns;++p){
138  std::vector<double> histogram(m_classes, 0.0);
139  for ( std::size_t k = 0; k != m_neighbors; ++k){
140  if (m_distanceWeights == UNIFORM) histogram[neighbors[p*m_neighbors+k].value]++;
141  else
142  {
143  double d = neighbors[p*m_neighbors+k].key;
144  if (d < 1e-100) histogram[neighbors[p*m_neighbors+k].value] += 1e100;
145  else histogram[neighbors[p*m_neighbors+k].value] += 1.0 / d;
146  }
147  }
148  output(p) = static_cast<unsigned int>(std::max_element(histogram.begin(),histogram.end()) - histogram.begin());
149  }
150  }
151 
152  /// from ISerializable, reads a model from an archive
153  void read(InArchive& archive){
154  archive & m_neighbors;
155  archive & m_classes;
156  }
157 
158  /// from ISerializable, writes a model to an archive
159  void write(OutArchive& archive) const{
160  archive & m_neighbors;
161  archive & m_classes;
162  }
163 
164 protected:
165  NearestNeighbors const* m_algorithm;
166 
167  /// number of classes
168  std::size_t m_classes;
169 
170  /// number of neighbors to be taken into account
171  std::size_t m_neighbors;
172 
173  /// type of distance-based weights computation
175 };
176 
177 
178 
179 }
180 #endif