BinaryTree.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Binary space-partitioning tree of data points.
6  *
7  *
8  *
9  * \author T. Glasmachers
10  * \date 2011
11  *
12  *
13  * \par Copyright 1995-2017 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://shark-ml.org/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_ALGORITHMS_NEARESTNEIGHBORS_BINARYTREE_H
36 #define SHARK_ALGORITHMS_NEARESTNEIGHBORS_BINARYTREE_H
37 
38 
39 #include <shark/Core/Exception.h>
43 
44 #include <boost/range/algorithm_ext/iota.hpp>
45 #include <boost/range/iterator_range.hpp>
46 #include <boost/math/special_functions/fpclassify.hpp>
47 namespace shark {
48 
49 
50 /// \brief Stopping criteria for tree construction.
51 ///
52 /// \par
53 /// Conditions for automatic tree construction.
54 /// The data structure allows to specify a maximum
55 /// bucket size (number of instances represented
56 /// by a leaf), and a maximum tree depth.
57 ///
58 /// \par
59 /// Note: If a data instance appears more often in
60 /// a dataset than specified by the maximum bucket
61 /// size then this condition will be violated; this
62 /// is because a space partitioning tree has no
63 /// means of separating a single point.
64 ///
66 {
67 public:
68  /// \brief Default constructor: only stop at trivial leaves
70  : m_maxDepth(0xffffffff)
71  , m_maxBucketSize(1)
72  { }
73 
74  /// \brief Copy constructor.
76  : m_maxDepth(other.m_maxDepth)
78  { }
79 
80  /// \brief Constructor.
81  ///
82  /// \param maxDepth stop as soon as the given tree depth is reached (zero means unrestricted)
83  /// \param maxBucketSize stop as soon as a node holds at most the bucket size of data points (zero means unrestricted)
84  TreeConstruction(unsigned int maxDepth, unsigned int maxBucketSize)
85  : m_maxDepth(maxDepth ? maxDepth : 0xffffffff)
86  , m_maxBucketSize(maxBucketSize ? maxBucketSize : 1)
87  { }
88 
89 
90  /// return a TreeConstruction object with maxDepth reduced by one
93 
94 
95  /// return maximum depth of the tree
96  unsigned int maxDepth() const
97  { return m_maxDepth; }
98 
99  /// return maximum "size" of a leaf node
100  unsigned int maxBucketSize() const
101  { return m_maxBucketSize; }
102 
103 protected:
104  /// maximum depth of the tree
105  unsigned int m_maxDepth;
106 
107  /// maximum "size" of a leaf node
108  unsigned int m_maxBucketSize;
109 };
110 
111 
112 ///
113 /// \brief Super class of binary space-partitioning trees.
114 ///
115 /// \par
116 /// This class represents a generic node in a binary
117 /// space-partitioning tree. At each junction the
118 /// space cell represented by the parent node is
119 /// split into sub-cells by thresholding a real-valued
120 /// function. Different sub-classes implement different
121 /// such functions. The absolute value of the function
122 /// minus the threshold m_threshold must represent the
123 /// distance to the separating hyper-surface in the
124 /// underlying metric. This allows for linear separation,
125 /// but also for kernel-induced feature spaces and other
126 /// embeddings.
127 template <class InputT>
129 {
130 public:
131  typedef InputT value_type;
132 
133  /// \brief Root node constructor: build the tree from data.
134  ///
135  /// Please refer the specific sub-classes such as KDTree
136  /// for examples of how the binary tree is built.
137  ///
138  BinaryTree(std::size_t size)
139  : mep_parent(NULL)
140  , mp_left(NULL)
141  , mp_right(NULL)
142  , mp_indexList(NULL)
143  , m_size(size)
144  , m_nodes(0)
145  , m_threshold(0.0)
146  {
147  SHARK_ASSERT(m_size > 0);
148 
149  // prepare list of index/pointer pairs to be shared among the whole tree
150  mp_indexList = new std::size_t[m_size];
151  std::iota(mp_indexList,mp_indexList+m_size,0);
152  }
153 
154  /// Destroy the tree and its internal data structures
155  virtual ~BinaryTree()
156  {
157  if (mp_left != NULL) delete mp_left;
158  if (mp_right != NULL) delete mp_right;
159  if (mep_parent == NULL) delete [] mp_indexList;
160  }
161 
162 
163  // binary tree structure
164 
165  /// parent node
167  { return mep_parent; }
168  /// parent node
169  const BinaryTree* parent() const
170  { return mep_parent; }
171 
172  /// Does this node have children?
173  /// Opposite of isLeaf()
174  bool hasChildren() const
175  { return (mp_left != NULL); }
176 
177  /// Is this a leaf node?
178  /// Opposite of hasChildren()
179  bool isLeaf() const
180  { return (mp_left == NULL); }
181 
182  /// "left" sub-node of the binary tree
184  { return mp_left; }
185  /// "left" sub-node of the binary tree
186  const BinaryTree* left() const
187  { return mp_left; }
188 
189  /// "right" sub-node of the binary tree
191  { return mp_right; }
192  /// "right" sub-node of the binary tree
193  const BinaryTree* right() const
194  { return mp_right; }
195 
196  /// number of points inside the space represented by this node
197  std::size_t size() const
198  { return m_size; }
199 
200  /// number of sub-nodes in this tree (including the node itself)
201  std::size_t nodes() const
202  { return m_nodes; }
203 
204  std::size_t index(std::size_t point)const{
205  return mp_indexList[point];
206  }
207 
208 
209  // partition represented by this node
210 
211  /// \brief Function describing the separation of space.
212  ///
213  /// \par
214  /// This function is shifted by subtracting the
215  /// threshold from the virtual function "funct" (which
216  /// acts as a "decision" function to split space into
217  /// sub-cells).
218  /// The result of this function describes "signed"
219  /// distance to the separation boundary, and the two
220  /// cells are thresholded at zero. We obtain the two
221  /// cells:<br/>
222  /// left ("negative") cell: {x | distance(x) < 0}<br/>
223  /// right ("positive") cell {x | distance(x) >= 0}
224  double distanceFromPlane(value_type const& point) const{
225  return funct(point) - m_threshold;
226  }
227 
228  /// \brief Separation threshold.
229  double threshold() const{
230  return m_threshold;
231  }
232 
233  /// return true if the reference point belongs to the
234  /// "left" sub-node, or to the "negative" sub-cell.
235  bool isLeft(value_type const& point) const
236  { return (funct(point) < m_threshold); }
237 
238  /// return true if the reference point belongs to the
239  /// "right" sub-node, or to the "positive" sub-cell.
240  bool isRight(value_type const& point) const
241  { return (funct(point) >= m_threshold); }
242 
243  /// \brief If the tree uses a kernel metric, returns a pointer to the kernel object, else NULL.
245  //default is no kernel metric
246  return NULL;
247  }
248 
249 
250  /// \brief Compute lower bound on the squared distance to the space cell.
251  ///
252  /// \par
253  /// For efficient use of the triangle inequality
254  /// the space cell represented by each node needs
255  /// to provide (a lower bound on) the distance of
256  /// a query point to the space cell represented by
257  /// this node. Search is fast if this bound is
258  /// tight.
259  virtual double squaredDistanceLowerBound(value_type const& point) const = 0;
260 
261 protected:
262  /// \brief Sub-node constructor
263  ///
264  /// \par
265  /// Initialize a sub-node
266  BinaryTree(BinaryTree* parent, std::size_t* list, std::size_t size)
267  : mep_parent(parent)
268  , mp_left(NULL)
269  , mp_right(NULL)
270  , mp_indexList(list)
271  , m_size(size)
272  , m_nodes(0)
273  {}
274 
275 
276  /// \brief Function describing the separation of space.
277  ///
278  /// \par
279  /// This is the primary interface for customizing
280  /// sub-classes.
281  ///
282  /// \par
283  /// This function splits the space represented by the
284  /// node by thresholding at m_threshold. The "negative"
285  /// cell, represented in the "left" sub-node, represents
286  /// the space {x | funct(x) < threshold}. The
287  /// "positive" cell, represented by the "right"
288  /// sub-node, represents {x | funct(x) >= threshold}.
289  /// The function needs to be normalized such that
290  /// |f(x) - f(y)| is the distance between
291  /// {z | f(z) = f(x)} and {z | f(z) = f(y)}, w.r.t.
292  /// the distance function also used by the virtual
293  /// function squaredDistanceLowerBound. The exact
294  /// distance function does not matter as long as
295  /// the same distance function is used in both cases.
296  virtual double funct(value_type const& point) const = 0;
297 
298  /// \brief Split the data in the point list and calculate the treshold accordingly
299  ///
300  /// The method computes the optimal threshold given the distance of every element of
301  /// the set and partitions the point range accordingly
302  /// @param values the value of every point returned by funct(point)
303  /// @param points the points themselves
304  /// @returns returns the position were the point list was split
305  template<class Range1, class Range2>
306  typename Range2::iterator splitList (Range1& values, Range2& points){
307  std::vector<KeyValuePair<typename Range1::value_type, typename Range2::value_type> > range(values.size());
308  for(std::size_t i = 0; i != range.size(); ++i){
309  range[i].key = values[i];
310  range[i].value = points[i];
311  }
312 
313 
314  auto pos = partitionEqually(range);
315  for(std::size_t i = 0; i != range.size(); ++i){
316  values[i] = range[i].key;
317  points[i] = range[i].value;
318  }
319 
320  if (pos == range.end()) {
321  // partitioning failed, all values are equal :(
322  m_threshold = values[0];
323  return points.begin();
324  }
325 
326  // We don't want the threshold to be the value of an element but always in between two of them.
327  // This ensures that no point of the training set lies on the boundary. This leeds to more stable
328  // results. So we use the mean of the found splitpoint and the nearest point on the other side
329  // of the boundary.
330  double maximum = std::max_element(range.begin(), pos)->key;
331  m_threshold = 0.5*(maximum + pos->key);
332 
333  return points.begin() + (pos - range.begin());
334  }
335 
336  /// parent node
338 
339  /// "left" child node
341 
342  /// "right" child node
344 
345  /// list of indices to points in the dataset
346  std::size_t* mp_indexList;
347 
348  /// number of points in this node
349  std::size_t m_size;
350 
351  /// number of nodes in the sub-tree represented by this node
352  std::size_t m_nodes;
353 
354  /// threshold for the separating function
355  double m_threshold;
356 
357 };
358 
359 
360 }
361 #endif