// Include guard for the CART tree model header, followed by the head of a
// class template parameterized on the leaf label type `LabelType`.
// NOTE(review): this file is a partial extract - the class declaration line
// and many body lines of the original header are not present here.
35 #ifndef SHARK_MODELS_TREES_CARTree_H 36 #define SHARK_MODELS_TREES_CARTree_H 49 template<
class LabelType>
// Serialization hook templated on the archive type (the
// `serialize(Archive&, const unsigned int version)` shape used by
// Boost.Serialization). Its body is not present in this extract.
64 template<
class Archive>
65 void serialize(Archive & ar,
const unsigned int version){
// Returns a freshly allocated EmptyState wrapped in a boost::shared_ptr<State>.
// Presumably the model keeps no per-evaluation state; the stateful batch eval
// below does not reference its state argument in the lines shown.
83 return boost::shared_ptr<State>(
new EmptyState());
// Batch evaluation: evaluates rows of `patterns` with evalPattern and fills
// `outputs` (the statements that size and store into `outputs` are not
// present in this extract).
// NOTE(review): row(patterns,0) is evaluated before any visible check of
// numPatterns. Unless a guard lives in a line not shown, an empty batch would
// read a non-existent row - confirm callers never pass zero rows.
88 void eval(BatchInputType
const& patterns, BatchOutputType & outputs)
const{
89 std::size_t numPatterns = patterns.size1();
// The first row's label is computed ahead of the loop, presumably so the
// output batch can be created from a sample label - verify against the
// missing lines.
91 LabelType
const& firstResult = evalPattern(row(patterns,0));
// The visible loop starts at i = 0. If its body evaluates every row, row 0 is
// evaluated a second time.
96 for(std::size_t i = 0; i != numPatterns; ++i){
// Stateful batch evaluation overload: forwards to the stateless batch eval.
// `state` is not referenced in the lines shown.
101 void eval(BatchInputType
const& patterns, BatchOutputType & outputs,
State& state)
const{
102 eval(patterns,outputs);
// Single-pattern evaluation: stores the label of the leaf reached by
// `pattern` into `output`.
// Marked const like the batch overloads. It only calls the const evalPattern,
// so it must also be callable on a const tree.
105 void eval(RealVector
const& pattern, LabelType& output)
const{
106 output = evalPattern(pattern);
// Deserialization fragment: reads the stored input dimension from the archive.
// Presumably the tree and labels are read in lines not present here - confirm.
128 archive >> m_inputDimension;
// Serialization fragment: writes the input dimension to the archive.
// It must stay in the same order as the matching read (`archive >> m_inputDimension`).
135 archive << m_inputDimension;
// Counts, per input attribute, how many internal (split) nodes test that
// attribute. The result has one zero-initialized entry per input dimension.
// NOTE(review): r(it->attributeIndex) assumes every split's attributeIndex is
// < m_inputDimension; no check is visible here.
141 UIntVector r(m_inputDimension, 0);
142 for(
auto it = m_tree.begin(); it != m_tree.end(); ++it) {
// leftId != 0 marks an internal node. Leaves use leftId == 0 and have no split attribute.
143 if(it->leftId != 0) {
144 r(it->attributeIndex)++;
// Returns the stored input dimension (the enclosing accessor's signature is not shown).
152 return m_inputDimension;
// Returns the total number of nodes (internal and leaf) stored in the tree.
163 return m_tree.size();
// Node accessor overloads: both return the node at index `nodeId`.
// No bounds check is visible in these lines.
169 return m_tree[nodeId];
174 return m_tree[nodeId];
// Returns the label stored for node `nodeId`. Leaves keep their label's index
// into m_labels in rightIdOrIndex, so this is only meaningful for leaf nodes.
// NOTE(review): no leaf/bounds check is visible here; for an internal node
// rightIdOrIndex is a child node id, not a label index.
177 LabelType
const&
getLabel(std::size_t nodeId)
const{
179 return m_labels[m_tree[nodeId].rightIdOrIndex];
// Root-creation fragment: clears rightIdOrIndex and appends the node, so it
// becomes index 0 only if m_tree was empty beforehand (presumably ensured by
// the lines not shown - confirm).
187 root.rightIdOrIndex = 0;
188 m_tree.push_back(root);
// Turns node `nodeId` into a split: appends two fresh leaf children
// (leftId == 0, rightIdOrIndex == 0) at the end of m_tree, points the parent's
// leftId/rightIdOrIndex at them, and returns the parent node.
// Child ids are std::size_t (the type of m_tree.size()) rather than int, so
// they cannot narrow or overflow for very large trees.
// NOTE(review): the returned reference points into m_tree and is presumably
// invalidated by later insertions (if TreeType is a std::vector) - do not
// hold it across further createSplit calls.
198 std::size_t nodeIdLeft = m_tree.size();
199 std::size_t nodeIdRight = m_tree.size() + 1;
203 leftChild.leftId = 0;
204 leftChild.rightIdOrIndex = 0;
207 rightChild.leftId = 0;
208 rightChild.rightIdOrIndex = 0;
210 m_tree.push_back(leftChild);
211 m_tree.push_back(rightChild);
// Index m_tree afresh after the push_backs instead of holding an earlier reference.
214 m_tree[nodeId].leftId = nodeIdLeft;
215 m_tree[nodeId].rightIdOrIndex = nodeIdRight;
219 return m_tree[nodeId];
// Makes node `nodeId` carry `label`: appends the label to m_labels and stores
// its index in rightIdOrIndex (the leaf encoding used by evalPattern/getLabel).
// NOTE(review): the line between these, not present here, presumably sets
// leftId = 0 so the node is treated as a leaf - confirm.
227 Node& node = m_tree[nodeId];
229 node.rightIdOrIndex = m_labels.size();
230 m_labels.push_back(label);
// Rebuilds m_tree in breadth-first order starting from the root (node 0),
// rewriting each internal node's child ids to their new positions.
// Leaves are copied unchanged: their rightIdOrIndex indexes m_labels, which
// this reordering does not touch.
240 TreeType reordered_tree;
241 reordered_tree.reserve(m_tree.size());
243 std::deque<std::size_t > bfs_queue;
244 bfs_queue.push_back(0);
// nodeId is the new id of the last node placed, so the next pair of children
// goes to nodeId+1 and nodeId+2 (its increment is in a line not present here).
246 std::size_t nodeId = 0;
247 while(!bfs_queue.empty()){
248 Node
const& node =
getNode(bfs_queue.front());
249 bfs_queue.pop_front();
// Leaf test. This was `!node.leftId == 0`, which parses as `(!node.leftId) == 0`
// (i.e. leftId != 0) and sent internal nodes down the leaf path, dropping
// their subtrees from the reordered tree.
252 if(node.leftId == 0){
253 reordered_tree.push_back(node);
// Internal node: remap its children to the next two BFS slots, then enqueue the old child ids.
255 reordered_tree.push_back(node);
256 reordered_tree.back().leftId = nodeId+1;
257 reordered_tree.back().rightIdOrIndex = nodeId+2;
259 bfs_queue.push_back(node.leftId);
260 bfs_queue.push_back(node.rightIdOrIndex);
264 m_tree = std::move(reordered_tree);
// Labels of leaf nodes. A leaf's rightIdOrIndex is an index into this vector.
269 std::vector<LabelType> m_labels;
// Walks the tree from the root (node 0) to a leaf and returns a reference to
// that leaf's label in m_labels.
// A node with leftId == 0 is a leaf, whose rightIdOrIndex indexes m_labels.
// For internal nodes, leftId/rightIdOrIndex are the two child node ids.
// The split test on `pattern` that picks between the two children is not
// among the lines present in this extract.
// NOTE(review): assumes the tree has at least a root node; m_tree[0] is read unconditionally.
272 template<
class Vector>
273 LabelType
const& evalPattern(Vector
const& pattern)
const{
274 std::size_t nodeId = 0;
275 while(m_tree[nodeId].
leftId != 0){
278 nodeId = m_tree[nodeId].leftId;
281 nodeId = m_tree[nodeId].rightIdOrIndex;
284 return m_labels[m_tree[nodeId].rightIdOrIndex];
// Input dimension of the model. It sizes the per-attribute count vector and is (de)serialized.
// NOTE(review): no default member initializer; confirm every constructor sets it.
288 std::size_t m_inputDimension;