// Include guard for SHARK_MODELS_CLASSIFIER_H, followed by a typedef that names the
// wrapped Model's batch output type (used below as the type of the raw model result
// in batch evaluation).
// NOTE(review): this line is a lossy extract; the leading numbers look like original
// line numbers, and the lines between the guard and the typedef are missing.
33 #ifndef SHARK_MODELS_CLASSIFIER_H 34 #define SHARK_MODELS_CLASSIFIER_H 67 typedef typename Model::BatchOutputType ModelBatchOutputType;
// Constructor initializer: initializes the stored decision function (a by-value Model
// member) from the `decisionFunction` argument.
// NOTE(review): the constructor's signature line is missing from this extract, so the
// parameter's type is not visible here.
76 : m_decisionFunction(decisionFunction){}
// Model name: "Classifier<" + the wrapped decision function's name() + ">".
79 {
return "Classifier<"+m_decisionFunction.name()+
">"; }
// Presumably the body of parameterVector(); the signature line is missing.
// It forwards to the wrapped decision function, so the parameters returned are that model's.
82 return m_decisionFunction.parameterVector();
// Presumably the body of setParameterVector(); the signature line is missing.
// It forwards newParameters unchanged to the wrapped decision function.
86 m_decisionFunction.setParameterVector(newParameters);
// Presumably the body of numberOfParameters(); the signature line is missing.
// It reports the parameter count of the wrapped decision function.
90 return m_decisionFunction.numberOfParameters();
// Presumably the body of inputShape(); the signature line is missing.
// The classifier accepts the same input shape as the wrapped decision function.
95 return m_decisionFunction.inputShape();
// Two accessor bodies that expose the wrapped decision function.
// NOTE(review): the signatures are missing; presumably these are the non-const and
// const decisionFunction() overloads. TODO confirm against the full header.
113 return m_decisionFunction;
118 return m_decisionFunction;
// Batch evaluation: run the wrapped decision function on the whole batch, then turn
// each row of scores into an unsigned class label stored in output(i).
//  - Single output column: label = (score + bias > 0.0), i.e. 1 or 0. The bias is
//    m_bias(0), or 0.0 when m_bias is empty.
//  - Several output columns: label = arg_max of the row. A second variant adds
//    m_bias to the row before taking arg_max.
// NOTE(review): this extract is missing original lines 131-133, 135, 137 and 139-141:
// the closing braces, the `else` that opens the multi-column branch, and the condition
// that picks between the two arg_max forms (presumably m_bias.empty()). TODO confirm.
121 void eval(BatchInputType
const& input, BatchOutputType& output)
const{
// If a bias is set, it must have exactly one entry per model output.
122 SIZE_CHECK(m_bias.empty() || m_decisionFunction.outputShape().numElements() == m_bias.size());
123 ModelBatchOutputType modelResult;
124 m_decisionFunction.eval(input,modelResult);
125 std::size_t
batchSize = modelResult.size1();
126 output.resize(batchSize);
127 if(modelResult.size2()== 1){
128 double bias = m_bias.empty()? 0.0 : m_bias(0);
129 for(std::size_t i = 0; i !=
batchSize; ++i){
// `+` binds tighter than `>`, so this computes (score + bias) > 0.0 and stores the
// bool as 0/1. The comparison is strict: a score of exactly -bias maps to class 0.
130 output(i) = modelResult(i,0) + bias > 0.0;
134 for(std::size_t i = 0; i !=
batchSize; ++i){
// Label = index of the largest score in row i.
136 output(i) =
static_cast<unsigned int>(arg_max(row(modelResult,i)));
// Same, but the bias vector is added to the row before arg_max.
138 output(i) =
static_cast<unsigned int>(arg_max(row(modelResult,i) + m_bias));
// Batch evaluation overload that also receives a State object.
// NOTE(review): the body (from original line 143 on) is missing from this extract.
// It presumably delegates to the stateless batch eval; TODO confirm.
142 void eval(BatchInputType
const& input, BatchOutputType& output,
State& state)
const{
// Single-pattern evaluation: run the wrapped decision function on one input, then turn
// its score vector into an unsigned class label.
//  - One output: label = (score + bias > 0.0). The bias is m_bias(0), or 0.0 when
//    m_bias is empty.
//  - Several outputs: label = arg_max of the scores. A second variant adds m_bias
//    before taking arg_max.
// NOTE(review): original lines 150, 154-156, 158 and 160+ are missing: the braces,
// the `else`, and the condition choosing between the two arg_max forms (presumably
// m_bias.empty()). TODO confirm.
146 void eval(InputType
const & pattern, OutputType& output)
const{
// If a bias is set, it must have exactly one entry per model output.
147 SIZE_CHECK(m_bias.empty() || m_decisionFunction.outputShape().numElements() == m_bias.size());
148 typename Model::OutputType modelResult;
149 m_decisionFunction.eval(pattern,modelResult);
151 if(modelResult.size() == 1){
152 double bias = m_bias.empty()? 0.0 : m_bias(0);
// Parsed as (score + bias) > 0.0, so the label is 0 or 1. The comparison is strict.
153 output = modelResult(0) + bias > 0.0;
// Label = index of the largest score.
157 output = static_cast<unsigned int>(arg_max(modelResult));
// Same, but the bias vector is added to the scores before arg_max.
159 output =
static_cast<unsigned int>(arg_max(modelResult + m_bias));
// Deserialization fragment: reads the wrapped decision function from the archive.
// NOTE(review): presumably part of read(). Whether m_bias is also read is not visible
// in this extract; verify, since eval() depends on m_bias.
166 archive >> m_decisionFunction;
// Serialization fragment: writes the wrapped decision function to the archive.
// NOTE(review): presumably part of write(). Whether m_bias is also written is not
// visible in this extract; verify.
171 archive << m_decisionFunction;
// The wrapped model whose outputs eval() converts into class labels; held by value.
176 Model m_decisionFunction;