include/shark/Algorithms/NearestNeighbors/NearestNeighbors.h
Go to the documentation of this file.
00001 //===========================================================================
00034 //===========================================================================
00035 
00036 #ifndef SHARK_ALGORITHMS_NEARESTNEIGHBORS_NEARESTNEIGHBORS_H
00037 #define SHARK_ALGORITHMS_NEARESTNEIGHBORS_NEARESTNEIGHBORS_H
00038 
00039 
00040 #include <boost/intrusive/rbtree.hpp>
00041 #include <shark/Models/Kernels/AbstractKernelFunction.h>
00042 #include <shark/Models/Trees/KDTree.h>
00043 #include <shark/Models/Trees/LCTree.h>
00044 #include <shark/Models/Trees/KHCTree.h>
00045 
00046 
00047 namespace shark {
00048 
00049 
00082 template <class InputT>
00083 class IterativeNNQuery
00084 {
00085     typedef BinaryTree<InputT> TreeType;
00086     typedef AbstractKernelFunction<InputT> KernelType;
00087 
00088 public:
00093     IterativeNNQuery(const TreeType* tree, InputT const& point, const KernelType* kernel = NULL)
00094     : m_reference(point)
00095     , m_nextIndex(0)
00096     , mp_trace(NULL)
00097     , mep_head(NULL)
00098     , m_squaredRadius(0.0)
00099     , m_neighbors(0)
00100     , m_kernel(kernel)
00101     {
00102         // Initialize the recursion trace: descend to the
00103         // leaf covering the reference point and queue it.
00104         // The parent of this leaf becomes the "head".
00105         mp_trace = new TraceNode(tree, NULL, m_reference);
00106         TraceNode* tn = mp_trace;
00107         while (tree->hasChildren())
00108         {
00109             if (tree->left()->hasChildren())
00110                 tn->mep_left = new TraceNode(tree->left(), tn, m_reference);
00111             else
00112                 tn->mep_left = new TraceLeaf(tree->left(), m_kernel, tn, m_reference);
00113             if (tree->right()->hasChildren())
00114                 tn->mep_right = new TraceNode(tree->right(), tn, m_reference);
00115             else
00116                 tn->mep_right = new TraceLeaf(tree->right(), m_kernel, tn, m_reference);
00117             bool left = tree->isLeft(m_reference);
00118             tn = (left ? tn->mep_left : tn->mep_right);
00119             tree = (left ? tree->left() : tree->right());
00120         }
00121         mep_head = tn->mep_parent;
00122         insertIntoQueue((TraceLeaf*)tn);
00123         m_squaredRadius = mp_trace->squaredRadius(m_reference);
00124     }
00125 
00127     ~IterativeNNQuery()
00128     { delete mp_trace; }
00129 
00130 
00132     std::size_t neighbors() const
00133     { return m_neighbors; }
00134 
00136     std::size_t next(double& dist, const InputT*& point)
00137     {
00138         if (m_neighbors >= mp_trace->m_tree->size()) throw SHARKEXCEPTION("[IterativeNNQuery::next] no more neighbors available");
00139 
00140         assert(! m_queue.empty());
00141 
00142         // Check whether the current node has points
00143         // left, or whether it should be discarded.
00144         if (m_neighbors > 0)
00145         {
00146             TraceLeaf& q = *m_queue.begin();
00147             if (m_nextIndex < q.m_tree->size())
00148             {
00149                 m_neighbors++;
00150                 std::size_t index = q.m_tree->index(m_nextIndex);
00151                 dist = std::sqrt(q.m_squaredPtDistance);
00152                 point = &q.m_tree->point(m_nextIndex);
00153                 m_nextIndex++;
00154                 return index;
00155             }
00156             else m_queue.erase(q);
00157         }
00158 
00159         if (m_queue.empty() || (*m_queue.begin()).m_squaredPtDistance > m_squaredRadius)
00160         {
00161             // enqueue more points
00162             TraceNode* tn = mep_head;
00163             while (tn != NULL)
00164             {
00165                 enqueue(tn);
00166                 if (tn->m_status == COMPLETE) mep_head = tn->mep_parent;
00167                 tn = tn->mep_parent;
00168             }
00169 
00170             // re-compute the radius
00171             m_squaredRadius = mp_trace->squaredRadius(m_reference);
00172         }
00173 
00174         m_neighbors++;
00175         TraceLeaf& q = *m_queue.begin();
00176         std::size_t index = q.m_tree->index(0);
00177         dist = std::sqrt(q.m_squaredPtDistance);
00178         point = &q.m_tree->point(0);
00179         m_nextIndex = 1;
00180         return index;
00181     }
00182 
00186     std::size_t queuesize() const
00187     { return m_queue.size(); }
00188 
00189 private:
00191     enum Status
00192     {
00193         NONE,            // no points of this node have been queued yet
00194         PARTIAL,         // some of the points of this node have been queued
00195         COMPLETE,        // all points of this node have been queued
00196     };
00197 
00203     class TraceNode
00204     {
00205     public:
00207         TraceNode(const TreeType* tree, TraceNode* parent, const InputT& ref)
00208         : m_tree(tree)
00209         , m_status(NONE)
00210         , mep_parent(parent)
00211         , mep_left(NULL)
00212         , mep_right(NULL)
00213         , m_squaredDistance(tree->squaredDistanceLowerBound(ref))
00214         { }
00215 
00217         ~TraceNode()
00218         {
00219             if (mep_left != NULL) delete mep_left;
00220             if (mep_right != NULL) delete mep_right;
00221         }
00222 
00223 
00230         inline double squaredRadius(const InputT& ref) const
00231         {
00232             if (m_status == NONE) return m_squaredDistance;
00233             else if (m_status == PARTIAL)
00234             {
00235                 double l = mep_left->squaredRadius(ref);
00236                 double r = mep_right->squaredRadius(ref);
00237                 return std::min(l, r);
00238             }
00239             else return 1e100;
00240         }
00241 
00243         const TreeType* m_tree;
00244 
00246         Status m_status;
00247 
00249         TraceNode* mep_parent;
00250 
00252         TraceNode* mep_left;
00253 
00255         TraceNode* mep_right;
00256 
00258         double m_squaredDistance;
00259     };
00260 
00262     typedef boost::intrusive::set_base_hook<> HookType;
00263 
00270     class TraceLeaf : public TraceNode, public HookType
00271     {
00272     public:
00274         TraceLeaf(const TreeType* tree, const KernelType* kernel, TraceNode* parent, const InputT& ref)
00275         : TraceNode(tree, parent, ref)
00276         , m_squaredPtDistance(kernel != NULL ? kernel->featureDistanceSqr((*tree)(0), ref) : shark::distanceSqr((*tree)(0), ref))
00277         { }
00278 
00280         ~TraceLeaf() { }
00281 
00282 
00285         inline bool operator < (TraceLeaf const& rhs) const
00286         {
00287             if (m_squaredPtDistance == rhs.m_squaredPtDistance) return (this->m_tree < rhs.m_tree);
00288             return (m_squaredPtDistance < rhs.m_squaredPtDistance);
00289         }
00290 
00292         double m_squaredPtDistance;
00293     };
00294 
00296     void insertIntoQueue(TraceLeaf* leaf)
00297     {
00298         m_queue.insert_unique(*leaf);
00299 
00300         // traverse up the tree, updating the state
00301         TraceNode* tn = leaf;
00302         tn->m_status = COMPLETE;
00303         while (true)
00304         {
00305             TraceNode* par = tn->mep_parent;
00306             if (par == NULL) break;
00307             if (par->m_status == NONE)
00308             {
00309                 par->m_status = PARTIAL;
00310                 break;
00311             }
00312             else if (par->m_status == PARTIAL)
00313             {
00314                 if (par->mep_left == tn)
00315                 {
00316                     if (par->mep_right->m_status == COMPLETE) par->m_status = COMPLETE;
00317                     else break;
00318                 }
00319                 else
00320                 {
00321                     if (par->mep_left->m_status == COMPLETE) par->m_status = COMPLETE;
00322                     else break;
00323                 }
00324             }
00325             tn = par;
00326         }
00327     }
00328 
00332     void enqueue(TraceNode* tn)
00333     {
00334         // check whether this node needs to be enqueued
00335         if (tn->m_status == COMPLETE) return;
00336         if (! m_queue.empty() && tn->m_squaredDistance >= (*m_queue.begin()).m_squaredPtDistance) return;
00337 
00338         const TreeType* tree = tn->m_tree;
00339         if (tree->hasChildren())
00340         {
00341             // extend the tree at need
00342             if (tn->mep_left == NULL)
00343             {
00344                 if (tree->left()->hasChildren())
00345                     tn->mep_left = new TraceNode(tree->left(), tn, m_reference);
00346                 else
00347                     tn->mep_left = new TraceLeaf(tree->left(), m_kernel, tn, m_reference);
00348             }
00349             if (tn->mep_right == NULL)
00350             {
00351                 if (tree->right()->hasChildren())
00352                     tn->mep_right = new TraceNode(tree->right(), tn, m_reference);
00353                 else
00354                     tn->mep_right = new TraceLeaf(tree->right(), m_kernel, tn, m_reference);
00355             }
00356 
00357             // first descend into the closer sub-tree
00358             if (tree->isLeft(m_reference))
00359             {
00360                 // left first
00361                 enqueue(tn->mep_left);
00362                 enqueue(tn->mep_right);
00363             }
00364             else
00365             {
00366                 // right first
00367                 enqueue(tn->mep_right);
00368                 enqueue(tn->mep_left);
00369             }
00370         }
00371         else
00372         {
00373             TraceLeaf* leaf = (TraceLeaf*)tn;
00374             insertIntoQueue(leaf);
00375         }
00376     }
00377 
00379     typedef boost::intrusive::rbtree<TraceLeaf> QueueType;
00380 
00382     InputT m_reference;
00383 
00385     QueueType m_queue;
00386 
00389     std::size_t m_nextIndex;
00390 
00392     TraceNode* mp_trace;
00393 
00398     TraceNode* mep_head;
00399 
00401     double m_squaredRadius;
00402 
00404     std::size_t m_neighbors;
00405 
00407     const KernelType* m_kernel;
00408 };
00409 
00410 
00411 }
00412 #endif