SimpleNearestNeighbors.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Efficient brute force implementation of nearest neighbors.
6  *
7  *
8  *
9  * \author O.Krause
10  * \date 2012
11  *
12  *
13  * \par Copyright 1995-2017 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://shark-ml.org/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_ALGORITHMS_NEARESTNEIGHBORS_SIMPLENEARESTNEIGHBORS_H
36 #define SHARK_ALGORITHMS_NEARESTNEIGHBORS_SIMPLENEARESTNEIGHBORS_H
37 
40 #include <shark/Core/OpenMP.h>
41 #include <algorithm>
42 
43 
44 namespace shark {
45 
46 ///\brief Brute force optimized nearest neighbor implementation
47 ///
48 ///Returns the labels and distances of the k nearest neighbors of a point
49 /// The distance is measured using an arbitrary metric
50 template<class InputType, class LabelType>
51 class SimpleNearestNeighbors:public AbstractNearestNeighbors<InputType,LabelType>{
52 private:
54 public:
59 
60  /// \brief Constructor.
61  ///
62  /// \par Construct a "brute force" nearest neighbors search algorithm
63  /// from data and a metric. Refer to the AbstractMetric class for details.
64  /// The "default" Euclidean metric is realized by providing a pointer to
65  /// an object of type LinearKernel<InputType>.
66  SimpleNearestNeighbors(Dataset const& dataset, Metric const* metric)
67  :m_dataset(dataset), mep_metric(metric){
68  this->m_inputShape=dataset.inputShape();
69  }
70 
71  ///\brief Return the k nearest neighbors of the query point.
72  std::vector<DistancePair> getNeighbors(BatchInputType const& patterns, std::size_t k)const{
73  std::size_t numPatterns = batchSize(patterns);
74  std::size_t maxThreads = std::min(SHARK_NUM_THREADS,m_dataset.numberOfBatches());
75  //heaps of key value pairs (distance,classlabel). One heap for every pattern and thread.
76  //For memory alignment reasons, all heaps are stored in one continuous array
77  //the heaps are stored such, that for every pattern the heaps for every thread
78  //are forming one memory area. so later we can just merge all 4 heaps using make_heap
79  //be aware that the values created here allready form a heap since they are all
80  //identical maximum distance.
81  std::vector<DistancePair> heaps(k*numPatterns*maxThreads,DistancePair(std::numeric_limits<double>::max(),LabelType()));
82  typedef typename std::vector<DistancePair>::iterator iterator;
83  //iterate over all batches of the training set in parallel and let
84  //every thread do a KNN-Search on it's subset of data
85  SHARK_PARALLEL_FOR(int b = 0; b < (int)m_dataset.numberOfBatches(); ++b){
86  //evaluate distances between the points of the patterns and the batch
87  RealMatrix distances=mep_metric->featureDistanceSqr(patterns,m_dataset.batch(b).input);
88 
89  //now update the heaps with the distances
90  for(std::size_t p = 0; p != numPatterns; ++p){
91  std::size_t batchSize = distances.size2();
92 
93  //get current heap
94  std::size_t heap = p*maxThreads+SHARK_THREAD_NUM;
95  iterator heapStart=heaps.begin()+heap*k;
96  iterator heapEnd=heapStart+k;
97  iterator biggest=heapEnd-1;//position of biggest element
98 
99  //update heap values using the new distances
100  for(std::size_t i = 0; i != batchSize; ++i){
101  if(biggest->key >= distances(p,i)){
102  //push the smaller neighbor in the heap and replace the biggest one
103  biggest->key=distances(p,i);
104  biggest->value=getBatchElement(m_dataset.batch(b).label,i);
105  std::push_heap(heapStart,heapEnd);
106  //pop biggest element, so that
107  //biggest is again the biggest element
108  std::pop_heap(heapStart,heapEnd);
109  }
110  }
111  }
112  }
113  std::vector<DistancePair> results(k*numPatterns);
114  //finally, we merge all threads in one heap which has the inverse ordering
115  //and create a class histogram over the smallest k neighbors
116  //std::cout<<"info "<<numPatterns<<" "<<maxThreads<<" "<<k<<std::endl;
117  SHARK_PARALLEL_FOR(int p = 0; p < (int)numPatterns; ++p){
118  //find range of the heaps for all threads
119  iterator heapStart=heaps.begin()+p*maxThreads*k;
120  iterator heapEnd=heapStart+maxThreads*k;
121  iterator neighborEnd=heapEnd-k;
122  iterator smallest=heapEnd-1;//position of biggest element
123  //create one single heap of the range with inverse ordering
124  //takes O(maxThreads*k)
125  std::make_heap(heapStart,heapEnd,std::greater<DistancePair>());
126 
127  //create histogram from the neighbors
128  for(std::size_t i = 0;heapEnd!=neighborEnd;--heapEnd,--smallest,++i){
129  std::pop_heap(heapStart,heapEnd,std::greater<DistancePair>());
130  results[i+p*k].key = smallest->key;
131  results[i+p*k].value = smallest->value;
132  }
133  }
134  return results;
135  }
136 
137  /// \brief Direct access to the underlying data set of nearest neighbor points.
139  return m_dataset;
140  }
141 
142 private:
143  Dataset m_dataset; ///< data set of nearest neighbor points
144  Metric const* mep_metric; ///< metric for measuring distances, usually given by a kernel function
145 };
146 
147 
148 }
149 #endif