35 #ifndef SHARK_MODELS_NEARESTNEIGHBORCLASSIFIER_H 36 #define SHARK_MODELS_NEARESTNEIGHBORCLASSIFIER_H 55 template <
class InputType>
84 {
return "NearestNeighborClassifier"; }
107 RealVector parameters(1);
124 return boost::shared_ptr<State>(
new EmptyState());
130 void eval(BatchInputType
const& patterns, BatchOutputType& output,
State& state)
const{
131 std::size_t numPatterns =
batchSize(patterns);
134 output.resize(numPatterns);
137 for(std::size_t p = 0; p != numPatterns;++p){
138 std::vector<double> histogram(
m_classes, 0.0);
144 if (d < 1e-100) histogram[neighbors[p*
m_neighbors+k].value] += 1e100;
145 else histogram[neighbors[p*
m_neighbors+k].value] += 1.0 / d;
148 output(p) =
static_cast<unsigned int>(std::max_element(histogram.begin(),histogram.end()) - histogram.begin());