00001
00034
00035
00036 #ifndef SHARK_ALGORITHMS_NEARESTNEIGHBORS_NEARESTNEIGHBORS_H
00037 #define SHARK_ALGORITHMS_NEARESTNEIGHBORS_NEARESTNEIGHBORS_H
00038
00039
00040 #include <boost/intrusive/rbtree.hpp>
00041 #include <shark/Models/Kernels/AbstractKernelFunction.h>
00042 #include <shark/Models/Trees/KDTree.h>
00043 #include <shark/Models/Trees/LCTree.h>
00044 #include <shark/Models/Trees/KHCTree.h>
00045
00046
00047 namespace shark {
00048
00049
00082 template <class InputT>
00083 class IterativeNNQuery
00084 {
00085 typedef BinaryTree<InputT> TreeType;
00086 typedef AbstractKernelFunction<InputT> KernelType;
00087
00088 public:
00093 IterativeNNQuery(const TreeType* tree, InputT const& point, const KernelType* kernel = NULL)
00094 : m_reference(point)
00095 , m_nextIndex(0)
00096 , mp_trace(NULL)
00097 , mep_head(NULL)
00098 , m_squaredRadius(0.0)
00099 , m_neighbors(0)
00100 , m_kernel(kernel)
00101 {
00102
00103
00104
00105 mp_trace = new TraceNode(tree, NULL, m_reference);
00106 TraceNode* tn = mp_trace;
00107 while (tree->hasChildren())
00108 {
00109 if (tree->left()->hasChildren())
00110 tn->mep_left = new TraceNode(tree->left(), tn, m_reference);
00111 else
00112 tn->mep_left = new TraceLeaf(tree->left(), m_kernel, tn, m_reference);
00113 if (tree->right()->hasChildren())
00114 tn->mep_right = new TraceNode(tree->right(), tn, m_reference);
00115 else
00116 tn->mep_right = new TraceLeaf(tree->right(), m_kernel, tn, m_reference);
00117 bool left = tree->isLeft(m_reference);
00118 tn = (left ? tn->mep_left : tn->mep_right);
00119 tree = (left ? tree->left() : tree->right());
00120 }
00121 mep_head = tn->mep_parent;
00122 insertIntoQueue((TraceLeaf*)tn);
00123 m_squaredRadius = mp_trace->squaredRadius(m_reference);
00124 }
00125
00127 ~IterativeNNQuery()
00128 { delete mp_trace; }
00129
00130
00132 std::size_t neighbors() const
00133 { return m_neighbors; }
00134
00136 std::size_t next(double& dist, const InputT*& point)
00137 {
00138 if (m_neighbors >= mp_trace->m_tree->size()) throw SHARKEXCEPTION("[IterativeNNQuery::next] no more neighbors available");
00139
00140 assert(! m_queue.empty());
00141
00142
00143
00144 if (m_neighbors > 0)
00145 {
00146 TraceLeaf& q = *m_queue.begin();
00147 if (m_nextIndex < q.m_tree->size())
00148 {
00149 m_neighbors++;
00150 std::size_t index = q.m_tree->index(m_nextIndex);
00151 dist = std::sqrt(q.m_squaredPtDistance);
00152 point = &q.m_tree->point(m_nextIndex);
00153 m_nextIndex++;
00154 return index;
00155 }
00156 else m_queue.erase(q);
00157 }
00158
00159 if (m_queue.empty() || (*m_queue.begin()).m_squaredPtDistance > m_squaredRadius)
00160 {
00161
00162 TraceNode* tn = mep_head;
00163 while (tn != NULL)
00164 {
00165 enqueue(tn);
00166 if (tn->m_status == COMPLETE) mep_head = tn->mep_parent;
00167 tn = tn->mep_parent;
00168 }
00169
00170
00171 m_squaredRadius = mp_trace->squaredRadius(m_reference);
00172 }
00173
00174 m_neighbors++;
00175 TraceLeaf& q = *m_queue.begin();
00176 std::size_t index = q.m_tree->index(0);
00177 dist = std::sqrt(q.m_squaredPtDistance);
00178 point = &q.m_tree->point(0);
00179 m_nextIndex = 1;
00180 return index;
00181 }
00182
00186 std::size_t queuesize() const
00187 { return m_queue.size(); }
00188
00189 private:
00191 enum Status
00192 {
00193 NONE,
00194 PARTIAL,
00195 COMPLETE,
00196 };
00197
00203 class TraceNode
00204 {
00205 public:
00207 TraceNode(const TreeType* tree, TraceNode* parent, const InputT& ref)
00208 : m_tree(tree)
00209 , m_status(NONE)
00210 , mep_parent(parent)
00211 , mep_left(NULL)
00212 , mep_right(NULL)
00213 , m_squaredDistance(tree->squaredDistanceLowerBound(ref))
00214 { }
00215
00217 ~TraceNode()
00218 {
00219 if (mep_left != NULL) delete mep_left;
00220 if (mep_right != NULL) delete mep_right;
00221 }
00222
00223
00230 inline double squaredRadius(const InputT& ref) const
00231 {
00232 if (m_status == NONE) return m_squaredDistance;
00233 else if (m_status == PARTIAL)
00234 {
00235 double l = mep_left->squaredRadius(ref);
00236 double r = mep_right->squaredRadius(ref);
00237 return std::min(l, r);
00238 }
00239 else return 1e100;
00240 }
00241
00243 const TreeType* m_tree;
00244
00246 Status m_status;
00247
00249 TraceNode* mep_parent;
00250
00252 TraceNode* mep_left;
00253
00255 TraceNode* mep_right;
00256
00258 double m_squaredDistance;
00259 };
00260
00262 typedef boost::intrusive::set_base_hook<> HookType;
00263
00270 class TraceLeaf : public TraceNode, public HookType
00271 {
00272 public:
00274 TraceLeaf(const TreeType* tree, const KernelType* kernel, TraceNode* parent, const InputT& ref)
00275 : TraceNode(tree, parent, ref)
00276 , m_squaredPtDistance(kernel != NULL ? kernel->featureDistanceSqr((*tree)(0), ref) : shark::distanceSqr((*tree)(0), ref))
00277 { }
00278
00280 ~TraceLeaf() { }
00281
00282
00285 inline bool operator < (TraceLeaf const& rhs) const
00286 {
00287 if (m_squaredPtDistance == rhs.m_squaredPtDistance) return (this->m_tree < rhs.m_tree);
00288 return (m_squaredPtDistance < rhs.m_squaredPtDistance);
00289 }
00290
00292 double m_squaredPtDistance;
00293 };
00294
00296 void insertIntoQueue(TraceLeaf* leaf)
00297 {
00298 m_queue.insert_unique(*leaf);
00299
00300
00301 TraceNode* tn = leaf;
00302 tn->m_status = COMPLETE;
00303 while (true)
00304 {
00305 TraceNode* par = tn->mep_parent;
00306 if (par == NULL) break;
00307 if (par->m_status == NONE)
00308 {
00309 par->m_status = PARTIAL;
00310 break;
00311 }
00312 else if (par->m_status == PARTIAL)
00313 {
00314 if (par->mep_left == tn)
00315 {
00316 if (par->mep_right->m_status == COMPLETE) par->m_status = COMPLETE;
00317 else break;
00318 }
00319 else
00320 {
00321 if (par->mep_left->m_status == COMPLETE) par->m_status = COMPLETE;
00322 else break;
00323 }
00324 }
00325 tn = par;
00326 }
00327 }
00328
00332 void enqueue(TraceNode* tn)
00333 {
00334
00335 if (tn->m_status == COMPLETE) return;
00336 if (! m_queue.empty() && tn->m_squaredDistance >= (*m_queue.begin()).m_squaredPtDistance) return;
00337
00338 const TreeType* tree = tn->m_tree;
00339 if (tree->hasChildren())
00340 {
00341
00342 if (tn->mep_left == NULL)
00343 {
00344 if (tree->left()->hasChildren())
00345 tn->mep_left = new TraceNode(tree->left(), tn, m_reference);
00346 else
00347 tn->mep_left = new TraceLeaf(tree->left(), m_kernel, tn, m_reference);
00348 }
00349 if (tn->mep_right == NULL)
00350 {
00351 if (tree->right()->hasChildren())
00352 tn->mep_right = new TraceNode(tree->right(), tn, m_reference);
00353 else
00354 tn->mep_right = new TraceLeaf(tree->right(), m_kernel, tn, m_reference);
00355 }
00356
00357
00358 if (tree->isLeft(m_reference))
00359 {
00360
00361 enqueue(tn->mep_left);
00362 enqueue(tn->mep_right);
00363 }
00364 else
00365 {
00366
00367 enqueue(tn->mep_right);
00368 enqueue(tn->mep_left);
00369 }
00370 }
00371 else
00372 {
00373 TraceLeaf* leaf = (TraceLeaf*)tn;
00374 insertIntoQueue(leaf);
00375 }
00376 }
00377
00379 typedef boost::intrusive::rbtree<TraceLeaf> QueueType;
00380
00382 InputT m_reference;
00383
00385 QueueType m_queue;
00386
00389 std::size_t m_nextIndex;
00390
00392 TraceNode* mp_trace;
00393
00398 TraceNode* mep_head;
00399
00401 double m_squaredRadius;
00402
00404 std::size_t m_neighbors;
00405
00407 const KernelType* m_kernel;
00408 };
00409
00410
00411 }
00412 #endif