KDTree.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3 *
4 * \brief Tree for nearest neighbor search in low dimensions.
5 *
6 * \author T. Glasmachers
7 * \date 2011
8 *
9 *
10 * <BR><HR>
11 * This file is part of Shark. This library is free software;
12 * you can redistribute it and/or modify it under the terms of the
13 * GNU General Public License as published by the Free Software
14 * Foundation; either version 3, or (at your option) any later version.
15 *
16 * This library is distributed in the hope that it will be useful,
17 * but WITHOUT ANY WARRANTY; without even the implied warranty of
18 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19 * GNU General Public License for more details.
20 *
21 * You should have received a copy of the GNU General Public License
22 * along with this library; if not, see <http://www.gnu.org/licenses/>.
23 *
24 */
25 //===========================================================================
26 
27 #ifndef SHARK_ALGORITHMS_NEARESTNEIGHBORS_KDTREE_H
28 #define SHARK_ALGORITHMS_NEARESTNEIGHBORS_KDTREE_H
29 
30 
32 #include <shark/Data/DataView.h>
33 #include <shark/LinAlg/Base.h>
34 #include <shark/Core/Math.h>
35 namespace shark {
36 
37 
38 ///
39 /// \brief KD-tree, a binary space-partitioning tree
40 ///
41 /// \par
42 /// KD-tree data structure for efficient nearest
43 /// neighbor searches in low-dimensional spaces.
44 /// Low-dimensional means << 10 dimensions.
45 ///
46 /// \par
47 /// The tree is constructed from data by splitting
48 /// the dimension with largest extent (in the data
49 /// covered, not space represented by the cell),
50 /// recursively. An approximate median is used as
51 /// the cutting point, where the maximal number of
52 /// points used to estimate the median can be
53 /// controlled with the template parameter
54 /// MedianAccuracy.
55 ///
56 /// \par
57 /// The bucket size for the tree is always one,
58 /// such that there is a unique point in each leaf
59 /// cell.
60 ///
61 template <class InputT>
62 class KDTree : public BinaryTree<InputT>
63 {
64  typedef KDTree<InputT> self_type;
66 public:
67 
68  /// Construct the tree from data.
69  /// It is assumed that the container exceeds
70  /// the lifetime of the KDTree (which holds
71  /// only references to the points), and that
72  /// the memory locations of the points remain
73  /// unchanged.
75  : base_type(dataset.numberOfElements())
76  , m_cutDim(0xffffffff){
77  typedef DataView<Data<RealVector> const> PointSet;
78  PointSet points(dataset);
79  //create a list to the iterator elements as temporary storage
80  std::vector<typename boost::range_iterator<PointSet>::type> elements(m_size);
81  boost::iota(elements,boost::begin(points));
82 
83  buildTree(tc,elements);
84 
85  //after the tree has been built, the iterators in the array are sorted in the order in which
86  //they are referenced by the nodes, so we can create the index list from the indices of the iterators
87  for(std::size_t i = 0; i != m_size; ++i){
88  mp_indexList[i] = elements[i].index();
89  }
90  }
91 
92 
93  /// lower bound on the box-shaped
94  /// space represented by this cell
95  double lower(std::size_t dim) const{
96  self_type* parent = static_cast<self_type*>(mep_parent);
97  if (parent == NULL) return -1e100;
98 
99  if (parent->m_cutDim == dim && parent->mp_right == this)
100  return parent->threshold();
101  else
102  return parent->lower(dim);
103  }
104 
105  /// upper bound on the box-shaped
106  /// space represented by this cell
107  double upper(std::size_t dim) const{
108  self_type* parent = static_cast<self_type*>(mep_parent);
109  if (parent == NULL) return +1e100;
110 
111  if (parent->m_cutDim == dim && parent->mp_left == this)
112  return parent->threshold();
113  else
114  return parent->upper(dim);
115  }
116 
117  /// \par
118  /// Compute the squared Euclidean distance of
119  /// this cell to the given reference point, or
120  /// alternatively a lower bound on this value.
121  ///
122  /// \par
123  /// In the case of the kd-tree the computation
124  /// is exact, however, only a lower bound is
125  /// required in general, and other space
126  /// partitioning trees used in the future may
127  /// only be able to provide a lower bound, at
128  /// least with reasonable computational efforts.
129  double squaredDistanceLowerBound(InputT const& reference) const
130  {
131  double ret = 0.0;
132  for (std::size_t d = 0; d != reference.size(); d++)
133  {
134  double v = reference(d);
135  double l = lower(d);
136  double u = upper(d);
137  if (v < l){
138  ret += sqr(l-v);
139  }
140  else if (v > u){
141  ret += sqr(v-u);
142  }
143  }
144  return ret;
145  }
146 
#if 0
// Debug helper, compiled out. Dumps the subtree rooted at this node to
// stdout: a leaf prints one "index:" line per point it holds, an inner
// node prints its cut followed by the size and the dump of each child.
// Every line is prefixed with `ident` copies of the indent string.
void print(unsigned int ident = 0) const
{
	if (this->isLeaf())
	{
		for (unsigned int k = 0; k < m_size; k++)
		{
			for (unsigned int s = 0; s < ident; s++) printf(" ");
			printf("index: %d\n", static_cast<int>(this->index(k)));
		}
		return;
	}

	// inner node: the cut first ...
	for (unsigned int s = 0; s < ident; s++) printf(" ");
	printf("x[%d] < %g\n", static_cast<int>(m_cutDim), this->threshold());

	// ... then the left child, one level deeper ...
	for (unsigned int s = 0; s < ident; s++) printf(" ");
	printf("[%d]\n", static_cast<int>(mp_left->size()));
	static_cast<self_type*>(mp_left)->print(ident + 1);

	// ... and finally the right child.
	for (unsigned int s = 0; s < ident; s++) printf(" ");
	printf("[%d]\n", static_cast<int>(mp_right->size()));
	static_cast<self_type*>(mp_right)->print(ident + 1);
}
#endif
172 
173 protected:
174  using base_type::mep_parent;
175  using base_type::mp_left;
176  using base_type::mp_right;
178  using base_type::m_size;
179  using base_type::m_nodes;
180 
181  /// (internal) construction of a non-root node
182  KDTree(KDTree* parent, std::size_t* list, std::size_t size)
183  : base_type(parent, list, size)
184  , m_cutDim(0xffffffff)
185  { }
186 
187  /// (internal) construction method:
188  /// median-cuts of the dimension with widest spread
189  template<class Range>
190  void buildTree(TreeConstruction tc, Range& points){
191  typedef typename boost::range_iterator<Range>::type iterator;
192 
193  iterator begin = boost::begin(points);
194  iterator end = boost::end(points);
195 
196  if (tc.maxDepth() == 0 || m_size <= tc.maxBucketSize()){
197  m_nodes = 1;
198  return;
199  }
200 
201  m_cutDim = calculateCuttingDimension(points);
202 
203  // calculate the distance of the boundary for every point in the list
204  std::vector<double> distance(m_size);
205  iterator point = begin;
206  for(std::size_t i = 0; i != m_size; ++i,++point){
207  distance[i] = get(**point,m_cutDim);
208  }
209 
210  // split the list into sub-cells
211  iterator split = this->splitList(distance,points);
212  if (split == end){
213  m_nodes = 1;
214  return;
215  }
216  std::size_t leftSize = split-begin;
217 
218  // create sub-nodes
219  mp_left = new KDTree(this, mp_indexList, leftSize);
220  mp_right = new KDTree(this, mp_indexList + leftSize, m_size - leftSize);
221 
222  // recurse
223  boost::iterator_range<iterator> left(begin,split);
224  boost::iterator_range<iterator> right(split,end);
225  ((KDTree*)mp_left)->buildTree(tc.nextDepthLevel(), left);
226  ((KDTree*)mp_right)->buildTree(tc.nextDepthLevel(), right);
227  m_nodes = 1 + mp_left->nodes() + mp_right->nodes();
228  }
229 
230  ///\brief Calculate the dimension which should be cut by this node
231  ///
232  ///The KD-Tree calculates the Axis-Aligned-Bounding-Box surrounding the points
233  ///and splits the longest dimension
234  template<class Range>
235  std::size_t calculateCuttingDimension(Range const& points)const{
236  typedef typename boost::range_iterator<Range const>::type iterator;
237 
238  iterator begin = boost::begin(points);
239  iterator end = boost::end(points);
240 
241  // calculate bounding box of the data
242  InputT L = **begin;
243  InputT U = **begin;
244  std::size_t dim = L.size();
245  iterator point = begin;
246  ++point;
247  for (std::size_t i=1; i != m_size; ++i,++point){
248  for (std::size_t d = 0; d != dim; d++){
249  double v = (**point)[d];
250  if (v < L[d]) L[d] = v;
251  if (v > U[d]) U[d] = v;
252  }
253  }
254 
255  // find the longest edge of the bounding box
256  std::size_t cutDim = 0;
257  double extent = U[0] - L[0];
258  for (std::size_t d = 1; d != dim; d++)
259  {
260  double e = U[d] - L[d];
261  if (e > extent)
262  {
263  extent = e;
264  cutDim = d;
265  }
266  }
267  return cutDim;
268  }
269 
270  /// Function describing the separating hyperplane
271  /// as its zero set. It is guaranteed that the
272  /// gradient has norm one, thus the absolute value
273  /// of the function indicates distance from the
274  /// hyperplane.
275  double funct(InputT const& reference) const{
276  return reference[m_cutDim];
277  }
278 
279  /// split/cut dimension of this node
280  std::size_t m_cutDim;
281 };
282 
283 
284 }
285 #endif