CARTree.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Cart Classifier
6  *
7  *
8  *
9  * \author K. N. Hansen, J. Kremer
10  * \date 2012
11  *
12  *
13  * \par Copyright 1995-2017 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://shark-ml.org/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_MODELS_TREES_CARTree_H
36 #define SHARK_MODELS_TREES_CARTree_H
37 
38 
40 #include <shark/Data/Dataset.h>
41 namespace shark {
42 
43 
44 ///
45 /// \brief Classification and Regression Tree.
46 ///
47 /// \par
48 /// The CARTree predicts a class label using a decision tree
49 template<class LabelType>
50 class CARTree : public AbstractModel<RealVector,LabelType>
51 {
52 private:
54 public:
57 
/// \brief A single node of the tree, addressed by its index in a TreeType.
///
/// The traversal code of this class treats a node with leftId == 0 as a
/// non-internal node (leaf or untyped); any other node is an internal node
/// that splits on one attribute.
struct Node{
    std::size_t attributeIndex; ///< index of the attribute an internal node compares
    double attributeValue;      ///< threshold: patterns with attribute <= value go to the left child
    std::size_t leftId;         ///< id of the left child; 0 if the node is not an internal node
    std::size_t rightIdOrIndex; ///< either id of right node or index to label array

    /// \brief Passes the four fields to the archive via operator&.
    ///
    /// The fields are handed over in the order attributeIndex, attributeValue,
    /// leftId, rightIdOrIndex. \a version is not used.
    template<class Archive>
    void serialize(Archive & ar, const unsigned int version){
        (void)version;
        ar & attributeIndex;
        ar & attributeValue;
        ar & leftId;
        ar & rightIdOrIndex;
    }
};
typedef std::vector<Node> TreeType;
73 
74 
75  /// Constructor
76  CARTree(std::size_t inputDimension = 0) : m_inputDimension(inputDimension){}
77 
78  /// \brief From INameable: return the class name.
79  std::string name() const
80  { return "CARTree"; }
81 
82  boost::shared_ptr<State> createState() const{
83  return boost::shared_ptr<State>(new EmptyState());
84  }
85 
86  using base_type::eval;
87  /// \brief Evaluate the Tree on a batch of patterns
88  void eval(BatchInputType const& patterns, BatchOutputType & outputs) const{
89  std::size_t numPatterns = patterns.size1();
90  //evaluate the first pattern alone and create the batch output from that
91  LabelType const& firstResult = evalPattern(row(patterns,0));
92  outputs = Batch<LabelType>::createBatch(firstResult,numPatterns);
93  getBatchElement(outputs,0) = firstResult;
94 
95  //evaluate the rest
96  for(std::size_t i = 0; i != numPatterns; ++i){
97  getBatchElement(outputs,i) = evalPattern(row(patterns,i));
98  }
99  }
100 
101  void eval(BatchInputType const& patterns, BatchOutputType & outputs, State& state) const{
102  eval(patterns,outputs);
103  }
104  /// \brief Evaluate the Tree on a single pattern
105  void eval(RealVector const& pattern, LabelType& output){
106  output = evalPattern(pattern);
107  }
108 
109  /// \brief The model does not have any parameters.
110  std::size_t numberOfParameters() const{
111  return 0;
112  }
113 
114  /// \brief The model does not have any parameters.
115  RealVector parameterVector() const {
116  return RealVector();
117  }
118 
119  /// \brief The model does not have any parameters.
120  void setParameterVector(RealVector const& param) {
121  SHARK_ASSERT(param.size() == 0);
122  }
123 
124  /// from ISerializable, reads a model from an archive
125  void read(InArchive& archive){
126  archive >> m_tree;
127  archive >> m_labels;
128  archive >> m_inputDimension;
129  }
130 
131  /// from ISerializable, writes a model to an archive
132  void write(OutArchive& archive) const {
133  archive << m_tree;
134  archive << m_labels;
135  archive << m_inputDimension;
136  }
137 
138  //Count how often attributes are used
139  UIntVector countAttributes() const {
140  SHARK_ASSERT(m_inputDimension > 0);
141  UIntVector r(m_inputDimension, 0);
142  for(auto it = m_tree.begin(); it != m_tree.end(); ++it) {
143  if(it->leftId != 0) { // not a label
144  r(it->attributeIndex)++;
145  }
146  }
147  return r;
148  }
149 
150  ///Return input dimension
151  Shape inputShape() const {
152  return m_inputDimension;
153  }
155  return Shape();
156  }
157 
158  ////////////////////////////////
159  /////Tree Construction routines
160  ///////////////////////////////
161 
162  std::size_t numberOfNodes() const{
163  return m_tree.size();
164  }
165 
166  /// \brief Returns the node with id nodeId
167  Node& getNode(std::size_t nodeId){
168  SIZE_CHECK(nodeId < m_tree.size());
169  return m_tree[nodeId];
170  }
171  /// \brief Returns the node with id nodeId
172  Node const& getNode(std::size_t nodeId)const{
173  SIZE_CHECK(nodeId < m_tree.size());
174  return m_tree[nodeId];
175  }
176 
177  LabelType const& getLabel(std::size_t nodeId)const{
178  SIZE_CHECK(nodeId < m_tree.size());
179  return m_labels[m_tree[nodeId].rightIdOrIndex];
180  }
181 
182  /// \brief Creates and returns an untyped root node (neither internal, nor leaf node)
183  Node& createRoot(){
184  m_tree.clear();
185  Node root;
186  root.leftId = 0;
187  root.rightIdOrIndex = 0;
188  m_tree.push_back(root);
189  return m_tree[0];
190  }
191 
192 
193  ///\brief Transforms an untyped node (no child, no internal node) into an internal node
194  ///
195  /// This creates already the two childs of the node, which are untyped.
196  Node& transformInternalNode(std::size_t nodeId, std::size_t attributeIndex, double attributeValue) {
197  // ids for new child nodes
198  int nodeIdLeft = m_tree.size();
199  int nodeIdRight = m_tree.size() + 1;
200 
201  //create new child nodes
202  Node leftChild;
203  leftChild.leftId = 0;
204  leftChild.rightIdOrIndex = 0;
205 
206  Node rightChild;
207  rightChild.leftId = 0;
208  rightChild.rightIdOrIndex = 0;
209 
210  m_tree.push_back(leftChild);
211  m_tree.push_back(rightChild);
212 
213  // connect the parent node with its two childs
214  m_tree[nodeId].leftId = nodeIdLeft;
215  m_tree[nodeId].rightIdOrIndex = nodeIdRight;
216  m_tree[nodeId].attributeIndex = attributeIndex;
217  m_tree[nodeId].attributeValue = attributeValue;
218 
219  return m_tree[nodeId];
220  }
221  ///\brief Transforms a node (no leaf) into a leaf node and inserts the appropriate label
222  ///
223  /// If the node was an internal node before, its connections get removed and the childs
224  /// are not reachable any more. Calling a reorder routine like reorderBFS() will get rid of those
225  /// nodes.
226  Node& transformLeafNode(std::size_t nodeId, LabelType const& label){
227  Node& node = m_tree[nodeId];
228  node.leftId = 0;
229  node.rightIdOrIndex = m_labels.size();
230  m_labels.push_back(label);
231  return node;
232  }
233 
234  /// \brief Reorders a tree into a breath-first-ordering
235  ///
236  /// This function call will remove all unreachable subtrees while reordering
237  /// the nodes by their depth in the tree, i.e. first comes the root, the the children
238  /// of the root, than their children, etc.
239  void reorderBFS(){
240  TreeType reordered_tree;
241  reordered_tree.reserve(m_tree.size());
242 
243  std::deque<std::size_t > bfs_queue;
244  bfs_queue.push_back(0);
245 
246  std::size_t nodeId = 0; //running id of the next node to insert
247  while(!bfs_queue.empty()){
248  Node const& node = getNode(bfs_queue.front());
249  bfs_queue.pop_front();
250 
251  //check leaf
252  if(!node.leftId == 0){
253  reordered_tree.push_back(node);
254  }else{
255  reordered_tree.push_back(node);
256  reordered_tree.back().leftId = nodeId+1;
257  reordered_tree.back().rightIdOrIndex = nodeId+2;
258  nodeId += 2;
259  bfs_queue.push_back(node.leftId);
260  bfs_queue.push_back(node.rightIdOrIndex);
261  }
262  }
263  //overwrite old tree with pruned tree
264  m_tree = std::move(reordered_tree);
265  }
266 private:
267  /// tree of the model
268  TreeType m_tree;
269  std::vector<LabelType> m_labels;
270 
271  /// Evaluate the CART tree on a single sample
272  template<class Vector>
273  LabelType const& evalPattern(Vector const& pattern) const{
274  std::size_t nodeId = 0;
275  while(m_tree[nodeId].leftId != 0){
276  if(pattern[m_tree[nodeId].attributeIndex] <= m_tree[nodeId].attributeValue){
277  //Branch on left node
278  nodeId = m_tree[nodeId].leftId;
279  }else{
280  //Branch on right node
281  nodeId = m_tree[nodeId].rightIdOrIndex;
282  }
283  }
284  return m_labels[m_tree[nodeId].rightIdOrIndex];
285  }
286 
287  ///Number of attributes (set by trainer)
288  std::size_t m_inputDimension;
289 };
290 
291 
292 }
293 #endif