NearestNeighborRegression.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Nearest neighbor regression
6  *
7  *
8  *
9  * \author T. Glasmachers, O.Krause
10  * \date 2012
11  *
12  *
13  * \par Copyright 1995-2017 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://shark-ml.org/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_MODELS_NEARESTNEIGHBORREGRESSION_H
36 #define SHARK_MODELS_NEARESTNEIGHBORREGRESSION_H
37 
38 
41 namespace shark {
42 
43 
/// \brief Nearest neighbor regression model.
///
/// The NearestNeighborRegression predicts a real-valued output vector for an
/// input as the average of the labels of its k nearest neighbors. Depending on
/// the distance-weight setting, the labels are averaged uniformly or weighted
/// by the inverse of their distance to the input. The neighbor search itself is
/// delegated to the NearestNeighbors algorithm passed to the constructor.
///
/// \tparam InputType Type of input data
53 template <class InputType>
54 class NearestNeighborRegression : public AbstractModel<InputType, RealVector>
55 {
56 public:
57 
62 
63  /// \brief Type of distance-based weights.
65  {
66  UNIFORM, ///< uniform (= no) distance-based weights
67  ONE_OVER_DISTANCE, ///< weight each neighbor's label with 1/distance
68  };
69 
70  ///\brief Constructor
71  ///
72  /// \param algorithm the used algorithm for nearst neighbor search
73  /// \param neighbors: number of neighbors
74  NearestNeighborRegression(NearestNeighbors const* algorithm, unsigned int neighbors = 3)
75  : m_algorithm(algorithm)
78  {}
79 
80  /// \brief From INameable: return the class name.
81  std::string name() const
82  { return "NearestNeighborRegression"; }
83 
84 
85  /// return the number of neighbors
86  unsigned int neighbors() const{
87  return m_neighbors;
88  }
89 
90  /// set the number of neighbors
91  void setNeighbors(unsigned int neighbors){
93  }
94 
95  /// query the way distances enter as weights
97  { return m_distanceWeights; }
98 
99  /// set the way distances enter as weights
101  { m_distanceWeights = dw; }
102 
103  /// get internal parameters of the model
104  virtual RealVector parameterVector() const{
105  RealVector parameters(1);
106  parameters(0) = m_neighbors;
107  return parameters;
108  }
109 
110  /// set internal parameters of the model
111  virtual void setParameterVector(RealVector const& newParameters){
112  SHARK_RUNTIME_CHECK(newParameters.size() == 1,
113  "[SoftNearestNeighborClassifier::setParameterVector] invalid number of parameters");
114  //~ SHARK_RUNTIME_CHECK((unsigned int)newParameters(0) == newParameters(0) && newParameters(0) >= 1.0,
115  //~ "[SoftNearestNeighborClassifier::setParameterVector] invalid number of neighbors");
116  m_neighbors = (unsigned int)newParameters(0);
117  }
118 
119  /// return the size of the parameter vector
120  virtual std::size_t numberOfParameters() const{
121  return 1;
122  }
123 
124  boost::shared_ptr<State> createState()const{
125  return boost::shared_ptr<State>(new EmptyState());
126  }
127 
128  using base_type::eval;
129 
130  /// soft k-nearest-neighbor prediction
131  void eval(BatchInputType const& patterns, BatchOutputType& output, State& state)const{
132  std::size_t numPatterns = batchSize(patterns);
133  std::vector<typename NearestNeighbors::DistancePair> neighbors = m_algorithm->getNeighbors(patterns,m_neighbors);
134 
135  std::size_t dimension = neighbors[0].value.size();
136  output.resize(numPatterns,dimension);
137  output.clear();
138 
139  for(std::size_t p = 0; p != numPatterns;++p)
140  {
141  double wsum = 0.0;
142  for ( std::size_t k = 0; k != m_neighbors; ++k)
143  {
144  double w;
145  if (m_distanceWeights == UNIFORM) w = 1.0;
146  else
147  {
148  double d = neighbors[p*m_neighbors+k].key;
149  if (d < 1e-100) w = 1e100;
150  else w = 1.0 / d;
151  }
152  noalias(row(output,p)) += w * neighbors[k+p*m_neighbors].value;
153  wsum += w;
154  }
155  row(output,p) *= (1.0 / wsum);
156  }
157  }
158 
159  /// from ISerializable, reads a model from an archive
160  void read(InArchive& archive){
161  archive & m_neighbors;
162  archive & m_classes;
163  }
164 
165  /// from ISerializable, writes a model to an archive
166  void write(OutArchive& archive) const{
167  archive & m_neighbors;
168  archive & m_classes;
169  }
170 
171 protected:
172  NearestNeighbors const* m_algorithm;
173 
174  /// number of classes
175  unsigned int m_classes;
176 
177  /// number of neighbors to be taken into account
178  unsigned int m_neighbors;
179 
180  /// type of distance-based weights computation
182 };
183 
184 
185 }
186 #endif