35 #ifndef SHARK_MODELS_TREES_CARTCLASSIFIER_H 36 #define SHARK_MODELS_TREES_CARTCLASSIFIER_H 57 template<
class LabelType>
77 template<
class Archive>
78 void serialize(Archive & ar,
const unsigned int version){
89 NodeInfo() : nodeId(0), attributeIndex(0), attributeValue(0), leftNodeId(0), rightNodeId(0), misclassProp(0), r(0), g(0) {}
91 explicit NodeInfo(std::size_t nodeId) : nodeId(nodeId), attributeIndex(0), attributeValue(0), leftNodeId(0), rightNodeId(0), misclassProp(0), r(0), g(0) {}
93 NodeInfo(std::size_t nodeId, LabelType label) : nodeId(nodeId), attributeIndex(0), attributeValue(0), leftNodeId(0), rightNodeId(0), label(
std::move(label)), misclassProp(0), r(0), g(0) {}
98 : nodeId{n.nodeId}, attributeIndex{n.attributeIndex},
99 attributeValue{n.attributeValue}, leftNodeId{n.leftNodeId},
100 rightNodeId{n.rightNodeId},
label(std::move(n.label)),
101 misclassProp{n.misclassProp}, r{n.r}, g{n.g}
106 attributeIndex = n.attributeIndex;
107 attributeValue = n.attributeValue;
108 leftNodeId = n.leftNodeId;
109 rightNodeId = n.rightNodeId;
110 label = std::move(n.label);
111 misclassProp = n.misclassProp;
151 CARTClassifier(TreeType&& tree, std::size_t d) BOOST_NOEXCEPT_IF((std::is_nothrow_constructible<TreeType,TreeType>::value))
159 {
return "CARTClassifier"; }
162 return boost::shared_ptr<State>(
new EmptyState());
167 void eval(BatchInputType
const& patterns, BatchOutputType & outputs)
const{
168 std::size_t numPatterns = patterns.size1();
175 for(std::size_t i = 0; i != numPatterns; ++i){
180 void eval(BatchInputType
const& patterns, BatchOutputType & outputs,
State& state)
const{
181 eval(patterns,outputs);
184 void eval(RealVector
const& pattern, LabelType& output){
229 typename TreeType::const_iterator it;
232 if(it->leftNodeId != 0) {
233 r(it->attributeIndex)++;
311 double accuracyPermutedOOB = 1. - lossOOB.
eval(pDataOOB.
labels(),pPredOOB);
346 double msePermutedOOB = lossOOB.
eval(pDataOOB.
labels(),pPredOOB);
359 std::size_t index = 0;
360 for(; nodeId != m_tree[index].nodeId; ++index);
369 for(std::size_t i = 0; i < tree.size(); i++){
373 for(std::size_t i = 0; i < tree.size(); i++){
379 template<
class Vector>
385 nodeId = m_tree[
nodeId].leftNodeId;
388 nodeId = m_tree[
nodeId].rightNodeId;
391 return m_tree[
nodeId].label;