SoftNearestNeighborClassifier.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Soft/probabilistic nearest neighbor classifier for vector-valued data.
6  *
7  *
8  *
9  * \author T. Glasmachers, C. Igel, O.Krause
10  * \date 2012-2014
11  *
12  *
13  * \par Copyright 1995-2017 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://shark-ml.org/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_MODELS_SOFTNEARESTNEIGHBOR_H
36 #define SHARK_MODELS_SOFTNEARESTNEIGHBOR_H
37 
#include <shark/Models/AbstractModel.h>
#include <shark/Algorithms/NearestNeighbors/AbstractNearestNeighbors.h>

#include <cstddef>
#include <string>
#include <vector>

42 namespace shark {
43 
44 /// \brief SoftNearestNeighborClassifier returns a probabilistic
45 /// classification by looking at the k nearest neighbors.
46 ///
47 /// For a given number C of classes, which has to be specified a
48 /// priori, a C-dimensional real-valued vector is returned for each
49 /// query point. Each component corresponds to a class and contains
50 /// the fraction of neighbors among the K nearest neighbors that
51 /// belong to the particular class.
52 ///
53 template <class InputType>
54 class SoftNearestNeighborClassifier : public AbstractModel<InputType, RealVector>
55 {
56 public:
61 
62  /// \brief Type of distance-based weights.
64  {
65  UNIFORM, ///< uniform (= no) distance-based weights
66  ONE_OVER_DISTANCE, ///< weight each neighbor's label with 1/distance
67  };
68 
69  /// \brief Constructor
70  ///
71  /// \param algorithm the used algorithm for nearest neighbor search
72  /// \param neighbors number of neighbors
73  SoftNearestNeighborClassifier(NearestNeighbors const* algorithm, unsigned int neighbors = 3)
74  : m_algorithm(algorithm)
75  , m_classes(numberOfClasses(algorithm->dataset()))
78  { }
79 
80  /// \brief Constructor
81  ///
82  /// \param algorithm the used algorithm for nearest neighbor search
83  /// \param numClasses number of classes (given explicitly, not derived from the training data)
84  /// \param neighbors number of neighbors
85  SoftNearestNeighborClassifier(NearestNeighbors const* algorithm, std::size_t numClasses, unsigned int neighbors)
86  : m_algorithm(algorithm)
87  , m_classes(numClasses)
88  , m_neighbors(neighbors)
90  { }
91 
92  /// \brief From INameable: return the class name.
93  std::string name() const
94  { return "SoftNearestNeighborClassifier"; }
95 
96 
97  /// return the number of neighbors
98  unsigned int neighbors() const{
99  return m_neighbors;
100  }
101 
102  /// set the number of neighbors
103  void setNeighbors(unsigned int neighbors){
105  }
106 
107  /// query the way distances enter as weights
109  { return m_distanceWeights; }
110 
111  /// set the way distances enter as weights
113  { m_distanceWeights = dw; }
114 
115  /// get internal parameters of the model
116  virtual RealVector parameterVector() const{
117  RealVector parameters(1);
118  parameters(0) = m_neighbors;
119  return parameters;
120  }
121 
122  /// set internal parameters of the model
123  virtual void setParameterVector(RealVector const& newParameters){
124  SHARK_RUNTIME_CHECK(newParameters.size() == 1,"Invalid number of parameters");
125  m_neighbors = (unsigned int)newParameters(0);
126  }
127 
128  /// return the size of the parameter vector
129  virtual std::size_t numberOfParameters() const{
130  return 1;
131  }
132 
133  boost::shared_ptr<State> createState()const{
134  return boost::shared_ptr<State>(new EmptyState());
135  }
136 
137  /// soft k-nearest-neighbor prediction
138  void eval(BatchInputType const& patterns, BatchOutputType& outputs) const {
139  std::size_t numPatterns = batchSize(patterns);
140  std::vector<typename NearestNeighbors::DistancePair> neighbors = m_algorithm->getNeighbors(patterns, m_neighbors);
141 
142  outputs.resize(numPatterns, m_classes);
143  outputs.clear();
144 
145  for(std::size_t p = 0; p != numPatterns;++p)
146  {
147  double wsum = 0.0;
148  for ( std::size_t k = 0; k != m_neighbors; ++k)
149  {
150  double w;
151  if (m_distanceWeights == UNIFORM) w = 1.0;
152  else
153  {
154  double d = neighbors[p*m_neighbors+k].key;
155  if (d < 1e-100) w = 1e100;
156  else w = 1.0 / d;
157  }
158 
159  outputs(p, neighbors[p*m_neighbors+k].value) += w;
160  wsum += w;
161  }
162  row(outputs, p) *= (1.0 / wsum);
163  }
164  }
165  void eval(BatchInputType const& patterns, BatchOutputType& outputs, State & state)const{
166  eval(patterns, outputs);
167  }
168 
169  using base_type::eval;
170 
171  /// from ISerializable, reads a model from an archive
172  void read(InArchive& archive){
173  archive & m_neighbors;
174  archive & m_classes;
175  }
176 
177  /// from ISerializable, writes a model to an archive
178  void write(OutArchive& archive) const{
179  archive & m_neighbors;
180  archive & m_classes;
181  }
182 
183 protected:
184  NearestNeighbors const* m_algorithm;
185 
186  /// number of classes
187  std::size_t m_classes;
188 
189  /// number of neighbors to be taken into account
190  unsigned int m_neighbors;
191 
192  /// type of distance-based weights computation
194 };
195 
196 
197 }
198 #endif