35 #ifndef SHARK_ALGORITHMS_NEARESTNEIGHBORS_TREENEARESTNEIGHBORS_H 36 #define SHARK_ALGORITHMS_NEARESTNEIGHBORS_TREENEARESTNEIGHBORS_H 39 #include <boost/intrusive/rbtree.hpp> 81 template <
class DataContainer>
// IterativeNNQuery<DataContainer>: incremental nearest-neighbor search over a
// space-partitioning tree. Neighbors are produced one at a time, closest first
// (as far as the visible code shows), by expanding a "trace" of the tree lazily.
//
// Type of a single stored point, taken from the container the tree indexes.
85 typedef typename DataContainer::value_type value_type;
// One query result: (distance, data index). getNextPoint() fills `first` with
// the square root of a stored squared distance and `second` with a tree index
// into the DataContainer the tree was built on.
88 typedef std::pair<double, std::size_t> result_type;
/// \brief Start an incremental nearest-neighbor query.
///
/// \param tree  root of the search tree; it becomes the root of the trace (mp_trace).
/// \param data  the container the tree's indices refer to. m_data is a const
///              reference member, so whatever it is bound to (presumably `data`,
///              in the initializer list not shown here) must outlive this query.
/// \param point the query point. The body uses m_reference, presumably
///              initialized from `point` in the hidden initializer list — TODO confirm.
94 IterativeNNQuery(tree_type
const* tree, DataContainer
const& data, value_type
const& point)
100 , m_squaredRadius(0.0)
// The trace mirrors the part of the tree explored so far; the root has no parent.
106 mp_trace =
new TraceNode(tree, NULL, m_reference);
107 TraceNode* tn = mp_trace;
// Descend from the root towards the reference point. Both children are created
// at every level so the sibling subtrees are available for later expansion.
// The loop header (original lines 108-109) is not visible here; presumably it
// stops once `tree` is a leaf.
110 tn->createLeftNode(tree, m_data, m_reference);
111 tn->createRightNode(tree, m_data, m_reference);
112 bool left = tree->
isLeft(m_reference);
113 tn = (left ? tn->mep_left : tn->mep_right);
114 tree = (left ? tree->
left() : tree->
right());
// Later expansion (in the neighbor-retrieval method) starts from mep_head.
116 mep_head = tn->mep_parent;
// NOTE(review): C-style downcast; it relies on `tn` having been allocated as a
// TraceLeaf by createLeftNode/createRightNode. static_cast<TraceLeaf*>(tn)
// would state that intent explicitly.
117 insertIntoQueue((TraceLeaf*)tn);
// Initial radius: the minimum of m_squaredDistance over the not-yet-expanded
// nodes of the trace (see squaredRadius). Presumably m_squaredDistance is a
// lower bound on the squared distance from the query to that subtree.
118 m_squaredRadius = mp_trace->squaredRadius(m_reference);
// Body of the method that returns the next neighbor as a result_type (its
// signature, original lines 131-134, is not visible in this excerpt).
// The root's size() is the number of points in the whole tree, so this check
// fails with "No more neighbors available" once m_neighbors (presumably the
// number of neighbors already returned) reaches that count.
135 SHARK_RUNTIME_CHECK(m_neighbors < mp_trace->m_tree->size(),
"No more neighbors available");
137 assert(! m_queue.empty());
// Fast path: keep draining the leaf at the front of the queue while it still
// has points that have not been reported yet.
141 if (m_neighbors > 0){
142 TraceLeaf& q = *m_queue.begin();
143 if (m_nextIndex < q.m_tree->size()){
144 return getNextPoint(q);
// Original lines 145-148 are not shown; presumably they remove the exhausted
// leaf from the queue and reset m_nextIndex.
//
// If the queue is empty, or the closest queued leaf lies beyond the current
// radius, the front leaf is not guaranteed to be the nearest remaining
// candidate: expand the trace upward from mep_head first.
149 if (m_queue.empty() || (*m_queue.begin()).m_squaredPtDistance > m_squaredRadius){
151 TraceNode* tn = mep_head;
// Once a node is fully explored, later expansions can start at its parent.
154 if (tn->m_status == COMPLETE) mep_head = tn->mep_parent;
// Refresh the radius now that more of the trace has been expanded.
159 m_squaredRadius = mp_trace->squaredRadius(m_reference);
// The front of the queue is now the closest candidate.
163 return getNextPoint(*m_queue.begin());
// Number of leaves currently linked into the priority queue m_queue (the
// enclosing accessor's signature, original lines 166-169, is not visible).
170 return m_queue.size();
/// Node of the search trace: one per tree node the query has visited.
/// `parent` is NULL for the root (the query constructor builds the root that
/// way). The initializer list (original lines 193-201) is not shown here.
192 TraceNode(tree_type
const* tree, TraceNode* parent, value_type
const& reference)
// Destructor body (its declaration, original line 203, is not visible):
// a trace node owns its children.
// NOTE(review): children may be TraceLeaf objects (createLeftNode and
// createRightNode allocate them) and are deleted here through TraceNode*,
// so ~TraceNode must be virtual — confirm on line 203.
// NOTE(review): TraceLeaf carries an intrusive set hook. Leaves should be
// unlinked from m_queue (e.g. the queue cleared) before the trace is deleted.
// The NULL checks are redundant: `delete` on a null pointer is a no-op.
204 if (mep_left != NULL)
delete mep_left;
205 if (mep_right != NULL)
delete mep_right;
/// Allocate the trace node for tree->left() and store it in mep_left. The node
/// is owned by this one and freed by its destructor.
/// It is either an inner TraceNode or a TraceLeaf. The condition choosing
/// between them (original lines 209/211) is not visible; presumably it is
/// whether the left subtree has children of its own.
208 void createLeftNode(tree_type
const* tree, DataContainer
const& data, value_type
const& reference){
// inner node: no point data needed
210 mep_left =
new TraceNode(tree->
left(),
this, reference);
// leaf: also takes `data` so it can compute the distance to its point
212 mep_left =
new TraceLeaf(tree->
left(),
this, data, reference);
/// Allocate the trace node for tree->right() and store it in mep_right. The
/// node is owned by this one and freed by its destructor.
/// It is either an inner TraceNode or a TraceLeaf. The selecting condition
/// (original lines 215/217) is not visible; presumably it mirrors
/// createLeftNode.
214 void createRightNode(tree_type
const* tree, DataContainer
const& data, value_type
const& reference){
// inner node: no point data needed
216 mep_right =
new TraceNode(tree->
right(),
this, reference);
// leaf: also takes `data` so it can compute the distance to its point
218 mep_right =
new TraceLeaf(tree->
right(),
this, data, reference);
/// Recursive radius of this subtree of the trace:
/// - an unexpanded node (NONE) contributes its own m_squaredDistance;
/// - a partially expanded node (PARTIAL) takes the minimum over both children.
/// The remaining case (COMPLETE, original lines 234-236) is not visible here.
/// Presumably m_squaredDistance is a lower bound on the squared distance from
/// the query to the node's region — TODO confirm in the hidden initializer list.
/// `ref` is only forwarded in the visible lines.
227 double squaredRadius(value_type
const& ref)
const{
228 if (m_status ==
NONE)
return m_squaredDistance;
229 else if (m_status == PARTIAL)
// Relies on a PARTIAL node always having both children. The constructor and
// enqueue create left and right children together.
231 double l = mep_left->squaredRadius(ref);
232 double r = mep_right->squaredRadius(ref);
233 return std::min(l, r);
// Tree node this trace node mirrors. Not deleted by the destructor (only
// mep_left/mep_right are).
239 tree_type
const* m_tree;
// Parent trace node; NULL for the root of the trace.
245 TraceNode* mep_parent;
// Right child, created lazily by createRightNode and owned by this node.
251 TraceNode* mep_right;
// Returned by squaredRadius() for unexpanded nodes and used by enqueue() to
// prune. Presumably a lower bound on the squared distance from the query to
// this subtree (set in the hidden initializer list).
254 double m_squaredDistance;
// Boost.Intrusive base hook with default options. It lets TraceLeaf objects be
// linked into the rbtree queue without the container copying or owning them.
258 typedef boost::intrusive::set_base_hook<> HookType;
/// Trace node for a leaf of the tree. It is also an element of the intrusive
/// priority queue m_queue, ordered by operator< below.
266 class TraceLeaf :
public TraceNode,
public HookType
/// Computes the squared distance from `ref` to the first point stored in the
/// leaf (tree->index(0)):
/// - if the tree carries a kernel, the distance is taken in that kernel's
///   feature space via featureDistanceSqr;
/// - otherwise distanceSqr is used.
/// NOTE(review): getNextPoint reports this same distance for every point of the
/// leaf. The reported distances are therefore only exact if all points in a
/// leaf are equidistant from the query (e.g. duplicates). Confirm against how
/// the tree is built.
270 TraceLeaf(tree_type
const* tree, TraceNode* parent, DataContainer
const& data, value_type
const& ref)
271 : TraceNode(tree, parent, ref){
273 if(tree->
kernel() != NULL)
274 m_squaredPtDistance = tree->
kernel()->featureDistanceSqr(data[tree->
index(0)], ref);
// no kernel: plain squared distance (the `else` on line 275 is not shown)
276 m_squaredPtDistance = distanceSqr(data[tree->
index(0)], ref);
/// Queue order: by squared point distance; ties are broken by the tree-node
/// address. That keeps distinct leaves at equal distance from being treated as
/// equivalent, which matters because the queue uses insert_unique. The exact
/// `==` on doubles is intentional: it only selects the tie-breaker.
/// NOTE(review): built-in `<` on pointers to unrelated objects yields an
/// unspecified order. std::less<tree_type const*> guarantees a strict total order.
285 inline bool operator < (TraceLeaf
const& rhs)
const{
286 if (m_squaredPtDistance == rhs.m_squaredPtDistance)
287 return (this->m_tree < rhs.m_tree);
289 return (m_squaredPtDistance < rhs.m_squaredPtDistance);
// Squared distance from the query to this leaf's first point; the queue key.
293 double m_squaredPtDistance;
/// Add `leaf` to the priority queue and propagate completion status upward:
/// - the leaf becomes COMPLETE;
/// - a parent in state NONE becomes PARTIAL;
/// - a PARTIAL parent becomes COMPLETE once its other child is COMPLETE too.
/// The enclosing loop (original line 303) and the statements after each status
/// change (lines 308-309, 313-315, 317-322) are not visible. Presumably the walk
/// stops when a parent does not become COMPLETE, and otherwise continues with
/// tn = par.
297 void insertIntoQueue(TraceLeaf* leaf){
// insert_unique: if an equivalent element (same distance and same tree node)
// is already queued, nothing is inserted.
298 m_queue.insert_unique(*leaf);
301 TraceNode* tn = leaf;
302 tn->m_status = COMPLETE;
304 TraceNode* par = tn->mep_parent;
// reached the root of the trace
305 if (par == NULL)
break;
306 if (par->m_status ==
NONE){
307 par->m_status = PARTIAL;
310 else if (par->m_status == PARTIAL){
// tn is the left child, so its sibling is the right child
311 if (par->mep_left == tn){
312 if (par->mep_right->m_status == COMPLETE) par->m_status = COMPLETE;
// tn is the right child (the `else` on lines 313-315 is not shown)
316 if (par->mep_left->m_status == COMPLETE) par->m_status = COMPLETE;
/// Report the point at position m_nextIndex of `leaf` as (distance, data index).
/// The distance is the square root of the leaf's stored squared distance, so
/// every point of one leaf is reported with the same distance.
/// Original line 327 is not shown; presumably it advances m_nextIndex (and
/// possibly m_neighbors) — TODO confirm.
324 result_type getNextPoint(TraceLeaf
const& leaf){
325 double dist = std::sqrt(leaf.m_squaredPtDistance);
326 std::size_t index = leaf.m_tree->index(m_nextIndex);
328 return std::make_pair(dist,index);
/// Expand the trace below `tn` and push reachable leaves into the queue.
/// - Nodes already COMPLETE are skipped.
/// - Nodes whose m_squaredDistance is not smaller than the best queued leaf's
///   squared distance are pruned. Assuming m_squaredDistance is a lower bound,
///   nothing closer can be found there.
/// - Children are created lazily and visited nearer side first (the side
///   containing the reference point). Close candidates are then queued early
///   and tighten the pruning test for the sibling.
/// The branch structure between original lines 340-364 is only partly visible.
/// Presumably lines 347-349 test whether `tree` has children, and lines 361-364
/// open the leaf branch.
334 void enqueue(TraceNode* tn){
336 if (tn->m_status == COMPLETE)
return;
337 if (! m_queue.empty() && tn->m_squaredDistance >= (*m_queue.begin()).m_squaredPtDistance)
return;
339 const tree_type* tree = tn->m_tree;
342 if (tn->mep_left == NULL){
343 tn->createLeftNode(tree,m_data,m_reference);
345 if (tn->mep_right == NULL){
346 tn->createRightNode(tree,m_data,m_reference);
// Recurse into the child on the reference point's side first.
350 if (tree->
isLeft(m_reference))
353 enqueue(tn->mep_left);
354 enqueue(tn->mep_right);
359 enqueue(tn->mep_right);
360 enqueue(tn->mep_left);
// Leaf of the tree.
// NOTE(review): the C-style downcast relies on `tn` having been created as a
// TraceLeaf; static_cast<TraceLeaf*>(tn) would state that intent explicitly.
365 TraceLeaf* leaf = (TraceLeaf*)tn;
366 insertIntoQueue(leaf);
// Priority queue of leaves: an intrusive red-black tree ordered by
// TraceLeaf::operator<, so begin() is the closest queued leaf. Elements are
// linked in place, not copied or owned by the container.
371 typedef boost::intrusive::rbtree<TraceLeaf> QueueType;
// Non-owning reference to the point data the tree indexes; it must outlive
// the query.
375 DataContainer
const& m_data;
// The query point, held by value.
378 value_type m_reference;
// Position, within the leaf at the front of the queue, of the next point to
// report.
385 std::size_t m_nextIndex;
// Last value of mp_trace->squaredRadius(m_reference). While the front leaf's
// squared distance does not exceed it, the front leaf is reported without
// further expansion.
397 double m_squaredRadius;
// Compared with the number of points in the tree to detect exhaustion.
// Presumably the number of neighbors reported so far.
400 std::size_t m_neighbors;
/// TreeNearestNeighbors<InputType, LabelType>: k-nearest-neighbor lookup
/// backed by a search tree (mep_tree). Returns the labels of the neighbors
/// found. The class head and constructor signature (original lines 408-420)
/// are not visible here.
407 template<
class InputType,
class LabelType>
// Cache the dataset's input and label containers.
421 , m_inputs(dataset.inputs())
422 , m_labels(dataset.labels())
/// \brief For every pattern in the batch, find its k nearest neighbors.
///
/// \param patterns batch of query points; batchSize(patterns) gives their count.
/// \param k        number of neighbors requested per pattern.
/// \return k*numPoints DistancePair entries (key = distance, value = label of
///         the neighbor). Pattern p's neighbors occupy indices p*k .. p*k+k-1,
///         in the order the query produced them.
///
/// Original lines 433-434 and 436 (creating the per-pattern query and fetching
/// `result`) are not visible. Presumably `result` is a (distance, index) pair
/// from an IterativeNNQuery. In that case requesting more neighbors than the
/// tree holds fails the "No more neighbors available" runtime check.
429 std::vector<DistancePair>
getNeighbors(BatchInputType
const& patterns, std::size_t k)
const{
430 std::size_t numPoints =
batchSize(patterns);
431 std::vector<DistancePair> results(k*numPoints);
432 for(std::size_t p = 0; p != numPoints; ++p){
435 for(std::size_t i = 0; i != k; ++i){
// result.first is the distance; result.second indexes the training set.
437 results[i+p*k].key=result.first;
438 results[i+p*k].value= m_labels[result.second];
// Non-owning reference to the labeled dataset; it must outlive this object.
449 Dataset
const& m_dataset;
// Search tree used for the queries. Whether this object owns it is not visible
// in this excerpt — confirm before changing its lifetime.
452 Tree
const* mep_tree;