TreeNearestNeighbors.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Efficient Nearest neighbor queries.
6  *
7  *
8  *
9  * \author T. Glasmachers
10  * \date 2011
11  *
12  *
13  * \par Copyright 1995-2017 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://shark-ml.org/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_ALGORITHMS_NEARESTNEIGHBORS_TREENEARESTNEIGHBORS_H
36 #define SHARK_ALGORITHMS_NEARESTNEIGHBORS_TREENEARESTNEIGHBORS_H
37 
38 
39 #include <boost/intrusive/rbtree.hpp>
42 #include <shark/Data/DataView.h>
43 namespace shark {
44 
45 
46 ///
47 /// \brief Iterative nearest neighbors query.
48 ///
49 /// \par
50 /// The IterativeNNQuery class (Iterative Nearest Neighbor
51 /// Query) allows the nearest neighbors of a reference point
52 /// to be queried iteratively. Given the reference point, a
53 /// query is set up that returns the nearest neighbor first,
54 /// then the second nearest neighbor, and so on.
55 /// Thus, nearest neighbor queries are treated in an "online"
56 /// fashion. The algorithm follows the paper (generalized to
57 /// arbitrary space-partitioning trees):
58 ///
59 /// \par
60 /// Strategies for efficient incremental nearest neighbor search.
61 /// A. J. Broder. Pattern Recognition 23(1/2), pp 171-178, 1990.
62 ///
63 /// \par
64 /// The algorithm is based on traversing a BinaryTree that
65 /// partitions the space into nested cells. The triangle
66 /// inequality is applied to exclude cells from the search.
67 /// Furthermore, candidate points are cached in a queue,
68 /// such that subsequent queries profit from points that
69 /// could not be excluded this way, but that did not turn
70 /// out to be the (current) nearest neighbor.
71 ///
72 /// \par
73 /// The tree must have a bucket size of one, but leaf nodes
74 /// with multiple copies of the same point are allowed.
75 /// This means that the space partitioning must be carried
76 /// out to the finest possible scale.
77 ///
78 /// The Data must be stored in a random access container. This means that elements
79 /// have O(1) access time. This is crucial for the performance of the tree lookup.
80 /// When data is stored in a Data<T>, a View should be chosen as template parameter.
81 template <class DataContainer>
83 {
84 public:
85  typedef typename DataContainer::value_type value_type;
88  typedef std::pair<double, std::size_t> result_type;
89 
90  /// create a new query
91  /// \param tree Underlying space-partitioning tree (this is assumed to persist for the lifetime of the query object).
92  /// \param data Container holding the stored data which is referenced by the tree
93  /// \param point Point whose nearest neighbors are to be found.
94  IterativeNNQuery(tree_type const* tree, DataContainer const& data, value_type const& point)
95  : m_data(data)
96  , m_reference(point)
97  , m_nextIndex(0)
98  , mp_trace(NULL)
99  , mep_head(NULL)
100  , m_squaredRadius(0.0)
101  , m_neighbors(0)
102  {
103  // Initialize the recursion trace: descend to the
104  // leaf covering the reference point and queue it.
105  // The parent of this leaf becomes the "head".
106  mp_trace = new TraceNode(tree, NULL, m_reference);
107  TraceNode* tn = mp_trace;
108  while (tree->hasChildren())
109  {
110  tn->createLeftNode(tree, m_data, m_reference);
111  tn->createRightNode(tree, m_data, m_reference);
112  bool left = tree->isLeft(m_reference);
113  tn = (left ? tn->mep_left : tn->mep_right);
114  tree = (left ? tree->left() : tree->right());
115  }
116  mep_head = tn->mep_parent;
117  insertIntoQueue((TraceLeaf*)tn);
118  m_squaredRadius = mp_trace->squaredRadius(m_reference);
119  }
120 
121  /// destroy the query object and its internal data structures
123  m_queue.clear();
124  delete mp_trace;
125  }
126 
127 
128  /// return the number of neighbors already found
129  std::size_t neighbors() const {
130  return m_neighbors;
131  }
132 
133  /// find and return the next nearest neighbor
134  result_type next() {
135  SHARK_RUNTIME_CHECK(m_neighbors < mp_trace->m_tree->size(), "No more neighbors available");
136 
137  assert(! m_queue.empty());
138 
139  // Check whether the current node has points
140  // left, or whether it should be discarded.
141  if (m_neighbors > 0){
142  TraceLeaf& q = *m_queue.begin();
143  if (m_nextIndex < q.m_tree->size()){
144  return getNextPoint(q);
145  }
146  else
147  m_queue.erase(q);
148  }
149  if (m_queue.empty() || (*m_queue.begin()).m_squaredPtDistance > m_squaredRadius){
150  // enqueue more points
151  TraceNode* tn = mep_head;
152  while (tn != NULL){
153  enqueue(tn);
154  if (tn->m_status == COMPLETE) mep_head = tn->mep_parent;
155  tn = tn->mep_parent;
156  }
157 
158  // re-compute the radius
159  m_squaredRadius = mp_trace->squaredRadius(m_reference);
160  }
161  m_nextIndex = 0;
162  ++m_neighbors;
163  return getNextPoint(*m_queue.begin());
164  }
165 
166  /// return the size of the queue,
167  /// which is a measure of the
168  /// overhead of the search
169  std::size_t queuesize() const{
170  return m_queue.size();
171  }
172 
173 private:
174 
175  /// status of a TraceNode object during the search
enum Status
	{
		NONE,     ///< no points of this node have been queued yet
		PARTIAL,  ///< some of the points of this node have been queued
		COMPLETE  ///< all points of this node have been queued
	};
182 
183  /// The TraceNode class builds up a tree during
184  /// the search. This tree covers only those parts
185  /// of the space partitioning tree that need to be
186  /// traversed in order to find the next nearest
187  /// neighbor.
188  class TraceNode
189  {
190  public:
191  /// Constructor
192  TraceNode(tree_type const* tree, TraceNode* parent, value_type const& reference)
193  : m_tree(tree)
194  , m_status(NONE)
195  , mep_parent(parent)
196  , mep_left(NULL)
197  , mep_right(NULL)
198  , m_squaredDistance(tree->squaredDistanceLowerBound(reference))
199  { }
200 
201  /// Destructor
202  virtual ~TraceNode()
203  {
204  if (mep_left != NULL) delete mep_left;
205  if (mep_right != NULL) delete mep_right;
206  }
207 
208  void createLeftNode(tree_type const* tree, DataContainer const& data, value_type const& reference){
209  if (tree->left()->hasChildren())
210  mep_left = new TraceNode(tree->left(), this, reference);
211  else
212  mep_left = new TraceLeaf(tree->left(), this, data, reference);
213  }
214  void createRightNode(tree_type const* tree, DataContainer const& data, value_type const& reference){
215  if (tree->right()->hasChildren())
216  mep_right = new TraceNode(tree->right(), this, reference);
217  else
218  mep_right = new TraceLeaf(tree->right(), this, data, reference);
219  }
220 
221  /// Compute the squared distance of the area not
222  /// yet covered by the queue to the reference point.
223  /// This is also referred to as the squared "radius"
224  /// of the area covered by the queue (in fact, it is
225  /// the radius of the largest sphere around the
226  /// reference point that fits into the covered area).
227  double squaredRadius(value_type const& ref) const{
228  if (m_status == NONE) return m_squaredDistance;
229  else if (m_status == PARTIAL)
230  {
231  double l = mep_left->squaredRadius(ref);
232  double r = mep_right->squaredRadius(ref);
233  return std::min(l, r);
234  }
235  else return 1e100;
236  }
237 
238  /// node of the tree
239  tree_type const* m_tree;
240 
241  /// status of the search
242  Status m_status;
243 
244  /// parent node
245  TraceNode* mep_parent;
246 
247  /// "left" child
248  TraceNode* mep_left;
249 
250  /// "right" child
251  TraceNode* mep_right;
252 
253  /// squared distance of the box to the reference point
254  double m_squaredDistance;
255  };
256 
257  /// hook type for intrusive container
258  typedef boost::intrusive::set_base_hook<> HookType;
259 
260  /// Leaves of the tree have three roles:
261  /// (1) they are tree nodes holding exactly one point
262  /// (possibly multiple copies of this point),
263  /// (2) they know the distance of their point to the
264  /// reference point,
265  /// (3) they can be added to the candidates queue.
266  class TraceLeaf : public TraceNode, public HookType
267  {
268  public:
269  /// Constructor
270  TraceLeaf(tree_type const* tree, TraceNode* parent, DataContainer const& data, value_type const& ref)
271  : TraceNode(tree, parent, ref){
272  //check whether the tree uses a differen metric than a linear one.
273  if(tree->kernel() != NULL)
274  m_squaredPtDistance = tree->kernel()->featureDistanceSqr(data[tree->index(0)], ref);
275  else
276  m_squaredPtDistance = distanceSqr(data[tree->index(0)], ref);
277  }
278 
279  /// Destructor
280  ~TraceLeaf() { }
281 
282 
283  /// Comparison by distance, ties are broken arbitrarily,
284  /// but deterministically, by tree node pointer.
285  inline bool operator < (TraceLeaf const& rhs) const{
286  if (m_squaredPtDistance == rhs.m_squaredPtDistance)
287  return (this->m_tree < rhs.m_tree);
288  else
289  return (m_squaredPtDistance < rhs.m_squaredPtDistance);
290  }
291 
292  /// Squared distance of the single point in the leaf to the reference point.
293  double m_squaredPtDistance;
294  };
295 
296  /// insert a point into the queue
297  void insertIntoQueue(TraceLeaf* leaf){
298  m_queue.insert_unique(*leaf);
299 
300  // traverse up the tree, updating the state
301  TraceNode* tn = leaf;
302  tn->m_status = COMPLETE;
303  while (true){
304  TraceNode* par = tn->mep_parent;
305  if (par == NULL) break;
306  if (par->m_status == NONE){
307  par->m_status = PARTIAL;
308  break;
309  }
310  else if (par->m_status == PARTIAL){
311  if (par->mep_left == tn){
312  if (par->mep_right->m_status == COMPLETE) par->m_status = COMPLETE;
313  else break;
314  }
315  else{
316  if (par->mep_left->m_status == COMPLETE) par->m_status = COMPLETE;
317  else break;
318  }
319  }
320  tn = par;
321  }
322  }
323 
324  result_type getNextPoint(TraceLeaf const& leaf){
325  double dist = std::sqrt(leaf.m_squaredPtDistance);
326  std::size_t index = leaf.m_tree->index(m_nextIndex);
327  ++m_nextIndex;
328  return std::make_pair(dist,index);
329  }
330 
331  /// Recursively descend the node and enqueue
332  /// all points in cells intersecting the
333  /// current bounding sphere.
334  void enqueue(TraceNode* tn){
335  // check whether this node needs to be enqueued
336  if (tn->m_status == COMPLETE) return;
337  if (! m_queue.empty() && tn->m_squaredDistance >= (*m_queue.begin()).m_squaredPtDistance) return;
338 
339  const tree_type* tree = tn->m_tree;
340  if (tree->hasChildren()){
341  // extend the tree at need
342  if (tn->mep_left == NULL){
343  tn->createLeftNode(tree,m_data,m_reference);
344  }
345  if (tn->mep_right == NULL){
346  tn->createRightNode(tree,m_data,m_reference);
347  }
348 
349  // first descend into the closer sub-tree
350  if (tree->isLeft(m_reference))
351  {
352  // left first
353  enqueue(tn->mep_left);
354  enqueue(tn->mep_right);
355  }
356  else
357  {
358  // right first
359  enqueue(tn->mep_right);
360  enqueue(tn->mep_left);
361  }
362  }
363  else
364  {
365  TraceLeaf* leaf = (TraceLeaf*)tn;
366  insertIntoQueue(leaf);
367  }
368  }
369 
370  /// the queue is a self-balancing tree of sorted entries
371  typedef boost::intrusive::rbtree<TraceLeaf> QueueType;
372 
373 
374  ///\brief Datastorage to lookup the points referenced by the space partitioning tree.
375  DataContainer const& m_data;
376 
377  /// reference point for this query
378  value_type m_reference;
379 
380  /// queue of candidates
381  QueueType m_queue;
382 
383  /// index of the next not yet returned element
384  /// of the current leaf.
385  std::size_t m_nextIndex;
386 
387  /// recursion trace tree
388  TraceNode* mp_trace;
389 
390  /// "head" of the trace tree. This is the
391  /// node containing the reference point,
392  /// but so high up in the tree that it is
393  /// not fully queued yet.
394  TraceNode* mep_head;
395 
396  /// squared radius of the "covered" area
397  double m_squaredRadius;
398 
399  /// number of neighbors already returned
400  std::size_t m_neighbors;
401 };
402 
403 
404 ///\brief Nearest Neighbors implementation using binary trees
405 ///
406 /// Returns the labels and distances of the k nearest neighbors of a point.
407 template<class InputType, class LabelType>
408 class TreeNearestNeighbors:public AbstractNearestNeighbors<InputType,LabelType>
409 {
410 private:
412 
413 public:
418 
419  TreeNearestNeighbors(Dataset const& dataset, Tree const* tree)
420  : m_dataset(dataset)
421  , m_inputs(dataset.inputs())
422  , m_labels(dataset.labels())
423  , mep_tree(tree)
424  {
425  this->m_inputShape = dataset.inputShape();
426  }
427 
428  ///\brief returns the k nearest neighbors of the point
429  std::vector<DistancePair> getNeighbors(BatchInputType const& patterns, std::size_t k)const{
430  std::size_t numPoints = batchSize(patterns);
431  std::vector<DistancePair> results(k*numPoints);
432  for(std::size_t p = 0; p != numPoints; ++p){
433  IterativeNNQuery<DataView<Data<InputType> const> > query(mep_tree, m_inputs, row(patterns, p));
434  //find the neighbors using the queries
435  for(std::size_t i = 0; i != k; ++i){
436  typename IterativeNNQuery<DataView<Data<InputType> const> >::result_type result = query.next();
437  results[i+p*k].key=result.first;
438  results[i+p*k].value= m_labels[result.second];
439  }
440  }
441  return results;
442  }
443 
445  return m_dataset;
446  }
447 
448 private:
449  Dataset const& m_dataset;
450  DataView<Data<InputType> const> m_inputs;
451  DataView<Data<LabelType> const> m_labels;
452  Tree const* mep_tree;
453 
454 };
455 
456 
457 }
458 #endif