NearestNeighborModel.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Nearest neighbor model for classification and regression
6  *
7  *
8  *
9  * \author T. Glasmachers, C. Igel, O.Krause
10  * \date 2012-2017
11  *
12  *
13  * \par Copyright 1995-2017 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://shark-ml.org/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_MODELS_NEARESTNEIGHBOR_H
36 #define SHARK_MODELS_NEARESTNEIGHBOR_H
37 
38 
42 
43 namespace shark {
44 
45 namespace detail{
46 template <class InputType, class LabelType>
47 class BaseNearestNeighbor : public AbstractModel<InputType, RealVector>
48 {
49 public:
50  typedef AbstractNearestNeighbors<InputType,LabelType> NearestNeighbors;
51  typedef AbstractModel<InputType, RealVector> base_type;
54 
55  /// \brief Constructor
56  ///
57  /// \param algorithm the used algorithm for nearest neighbor search
58  /// \param neighbors number of neighbors
59  BaseNearestNeighbor(NearestNeighbors const* algorithm, std::size_t outputDimensions, unsigned int neighbors = 3)
60  : m_algorithm(algorithm)
61  , m_outputDimensions(outputDimensions)
62  , m_neighbors(neighbors)
63  , m_uniform(true)
64  { }
65 
66  /// \brief From INameable: return the class name.
67  std::string name() const
68  { return "Internal"; }
69 
70  Shape inputShape() const{
71  return m_algorithm->inputShape();
72  }
73  Shape outputShape() const{
74  return Shape(m_outputDimensions);
75  }
76 
77  /// return the number of neighbors
78  unsigned int neighbors() const{
79  return m_neighbors;
80  }
81 
82  /// set the number of neighbors
83  void setNeighbors(unsigned int neighbors){
84  m_neighbors=neighbors;
85  }
86 
87  bool uniformWeights() const{
88  return m_uniform;
89  }
90  bool& uniformWeights(){
91  return m_uniform;
92  }
93 
94  /// get internal parameters of the model
95  virtual RealVector parameterVector() const{
96  RealVector parameters(1);
97  parameters(0) = m_neighbors;
98  return parameters;
99  }
100 
101  /// set internal parameters of the model
102  virtual void setParameterVector(RealVector const& newParameters){
103  SHARK_RUNTIME_CHECK(newParameters.size() == 1,"Invalid number of parameters");
104  m_neighbors = (unsigned int)newParameters(0);
105  }
106 
107  /// return the size of the parameter vector
108  virtual std::size_t numberOfParameters() const{
109  return 1;
110  }
111 
112  boost::shared_ptr<State> createState()const{
113  return boost::shared_ptr<State>(new EmptyState());
114  }
115 
116  /// soft k-nearest-neighbor prediction
117  void eval(BatchInputType const& patterns, BatchOutputType& outputs) const {
118  std::size_t numPatterns = batchSize(patterns);
119  std::vector<typename NearestNeighbors::DistancePair> neighbors = m_algorithm->getNeighbors(patterns, m_neighbors);
120 
121  outputs.resize(numPatterns, m_outputDimensions);
122  outputs.clear();
123 
124  for(std::size_t p = 0; p != numPatterns;++p)
125  {
126  double wsum = 0.0;
127  for ( std::size_t k = 0; k != m_neighbors; ++k)
128  {
129  double w = 1.0;
130  if (!m_uniform){
131  double d = neighbors[p*m_neighbors+k].key;
132  if (d < 1e-100) w = 1e100;
133  else w = 1.0 / d;
134  }
135  updatePrediction(outputs, p, w, neighbors[p*m_neighbors+k].value);
136  wsum += w;
137  }
138  row(outputs, p) /= wsum;
139  }
140  }
141 
142  void eval(BatchInputType const& patterns, BatchOutputType& outputs, State&) const {
143  eval(patterns,outputs);
144  }
145  using base_type::eval;
146 
147  /// from ISerializable, reads a model from an archive
148  void read(InArchive& archive){
149  archive & m_neighbors;
150  archive & m_outputDimensions;
151  archive & m_uniform;
152  }
153 
154  /// from ISerializable, writes a model to an archive
155  void write(OutArchive& archive) const{
156  archive & m_neighbors;
157  archive & m_outputDimensions;
158  archive & m_uniform;
159  }
160 
161 private:
162  void updatePrediction(RealMatrix& outputs, std::size_t p, double w, unsigned int const label) const{
163  outputs(p, label) += w;
164  }
165  template<class T>
166  void updatePrediction(RealMatrix& outputs, std::size_t p, double w, blas::vector<T> const& label)const{
167  noalias(row(outputs,p)) += w * label;
168  }
169  NearestNeighbors const* m_algorithm;
170 
171  /// number of classes
172  std::size_t m_outputDimensions;
173 
174  /// number of neighbors to be taken into account
175  unsigned int m_neighbors;
176 
177  /// type of distance-based weights computation
178  bool m_uniform;
179 };
180 }
181 
182 /// \brief NearestNeighbor model for classification and regression
183 ///
184 /// For classification, the model predicts a class label
185 /// according to a local majority decision among its k
186 /// nearest neighbors. It is not specified how ties are
187 /// broken.
188 ///
189 /// For Regression, the (weighted) mean of the k nearest
190 /// neighbours is computed.
191 template <class InputType, class LabelType>
192 class NearestNeighborModel: public detail::BaseNearestNeighbor<InputType,LabelType>
193 {
194 public:
196  typedef detail::BaseNearestNeighbor<InputType,LabelType> base_type;
197 
198  /// \brief Type of distance-based weights.
200  UNIFORM, ///< uniform (= no) distance-based weights
201  ONE_OVER_DISTANCE, ///< weight each neighbor's label with 1/distance
202  };
203 
204  /// \brief Constructor
205  ///
206  /// \param algorithm the used algorithm for nearest neighbor search; it is
     ///        dereferenced here to obtain the label dimension of its dataset,
     ///        so it must not be null
207  /// \param neighbors number of neighbors
208  NearestNeighborModel(NearestNeighbors const* algorithm, unsigned int neighbors = 3)
209  : base_type(algorithm, labelDimension(algorithm->dataset()), neighbors)
210  { }
211 
212  /// \brief From INameable: return the class name.
213  std::string name() const
214  { return "NearestNeighbor"; }
215 
216  /// query the way distances enter as weights
     /// NOTE(review): detail::BaseNearestNeighbor declares uniformWeights()
     /// directly and no decisionFunction(); unless another base class
     /// supplies decisionFunction(), this presumably should be
     /// this->uniformWeights() - confirm.
218  return this->decisionFunction().uniformWeights() ? UNIFORM : ONE_OVER_DISTANCE;
219  }
220 
221  /// set the way distances enter as weights
     /// NOTE(review): same concern as for the query above regarding
     /// decisionFunction() - confirm.
223  this->decisionFunction().uniformWeights() = (dw == UNIFORM);
224  }
225 };
226 
227 
/// \brief Specialization of NearestNeighborModel for unsigned int labels.
///
/// The nearest neighbor logic lives in a detail::BaseNearestNeighbor that is
/// wrapped by Classifier; the accessors below forward to it through
/// decisionFunction().
/// NOTE(review): presumably Classifier turns the per-class weights of the
/// wrapped model into a single predicted label - confirm against Classifier.
228 template <class InputType>
229 class NearestNeighborModel<InputType, unsigned int>: public Classifier<detail::BaseNearestNeighbor<InputType,unsigned int> >
230 {
231 public:
234 
235  /// \brief Type of distance-based weights.
237  UNIFORM, ///< uniform (= no) distance-based weights
238  ONE_OVER_DISTANCE, ///< weight each neighbor's label with 1/distance
239  };
240 
241  /// \brief Constructor
242  ///
243  /// \param algorithm the used algorithm for nearest neighbor search; it is
     ///        dereferenced here to obtain the number of classes of its
     ///        dataset, so it must not be null
244  /// \param neighbors number of neighbors
245  NearestNeighborModel(NearestNeighbors const* algorithm, unsigned int neighbors = 3)
246  : base_type(detail::BaseNearestNeighbor<InputType,unsigned int>(algorithm, numberOfClasses(algorithm->dataset()), neighbors))
247  { }
248 
249  /// \brief From INameable: return the class name.
250  std::string name() const
251  { return "NearestNeighbor"; }
252 
253  /// return the number of neighbors (forwarded to the wrapped decision function)
254  unsigned int neighbors() const{
255  return this->decisionFunction().neighbors();
256  }
257 
258  /// set the number of neighbors (forwarded to the wrapped decision function)
259  void setNeighbors(unsigned int neighbors){
260  this->decisionFunction().setNeighbors(neighbors);
261  }
262 
263  /// query the way distances enter as weights:
     /// UNIFORM if the wrapped model's uniformWeights() flag is set,
     /// ONE_OVER_DISTANCE otherwise
265  return this->decisionFunction().uniformWeights() ? UNIFORM : ONE_OVER_DISTANCE;
266  }
267 
268  /// set the way distances enter as weights:
     /// the wrapped model's uniformWeights() flag becomes true exactly when dw is UNIFORM
270  this->decisionFunction().uniformWeights() = (dw == UNIFORM);
271  }
272 };
273 
274 
275 }
276 #endif