CARTClassifier.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief CART Classifier
6  *
7  *
8  *
9  * \author K. N. Hansen, J. Kremer
10  * \date 2012
11  *
12  *
13  * \par Copyright 1995-2017 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://shark-ml.org/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_MODELS_TREES_CARTCLASSIFIER_H
36 #define SHARK_MODELS_TREES_CARTCLASSIFIER_H
37 
38 
42 #include <shark/Data/Dataset.h>
43 
44 namespace shark {
45 
46 
47 ///
48 /// \brief CART Classifier.
49 ///
50 /// \par
51 /// The CARTClassifier predicts a class label
52 /// using the CART algorithm.
53 ///
54 /// \par
55 /// It is a decision tree algorithm.
56 ///
// NOTE(review): this listing is a Doxygen text dump; each line keeps its original
// source line number, and several original lines were dropped by the extraction.
// Every gap is flagged below with the missing original line number(s) — recover
// the exact text from the original CARTClassifier.h before compiling.

/// A CART decision tree model: predicts a label (class histogram or regression
/// value, depending on LabelType) by descending a binary tree of axis-aligned
/// threshold splits stored in m_tree.
57 template<class LabelType>
58 class CARTClassifier : public AbstractModel<RealVector,LabelType>
59 {
60 private:
// NOTE(review): original lines 61 and 63-64 are missing here — presumably the
// base-class/batch typedefs (base_type, BatchInputType, BatchOutputType) that
// are referenced below at lines 165-167. TODO: confirm against the original.
62 public:
65 // Information about a single split. misclassProp, r and g are variables used in the cost complexity step
// One tree node: an internal node when leftNodeId != 0 (split on
// attributeIndex at attributeValue), otherwise a leaf carrying `label`.
66  struct NodeInfo {
67  std::size_t nodeId;
68  std::size_t attributeIndex;
// NOTE(review): original line 69 is missing — it must declare the member
// `attributeValue` (the split threshold), since it is serialized, initialized
// in every constructor, and compared against in evalPattern below. Its exact
// type (presumably double) must be confirmed against the original header.
70  std::size_t leftNodeId;
71  std::size_t rightNodeId;
72  LabelType label;
73  double misclassProp;//TODO: remove this
74  std::size_t r;//TODO: remove this
75  double g;//TODO: remove this
76 
// Boost.Serialization hook: archives every member, including the
// cost-complexity scratch fields slated for removal above.
77  template<class Archive>
78  void serialize(Archive & ar, const unsigned int version){
79  ar & nodeId;
80  ar & attributeIndex;
81  ar & attributeValue;
82  ar & leftNodeId;
83  ar & rightNodeId;
84  ar & label;
85  ar & misclassProp;
86  ar & r;
87  ar & g;
88  }
// Default node: all ids/fields zero; `label` is left default-constructed.
89  NodeInfo() : nodeId(0), attributeIndex(0), attributeValue(0), leftNodeId(0), rightNodeId(0), misclassProp(0), r(0), g(0) {}
90 
91  explicit NodeInfo(std::size_t nodeId) : nodeId(nodeId), attributeIndex(0), attributeValue(0), leftNodeId(0), rightNodeId(0), misclassProp(0), r(0), g(0) {}
92 
93  NodeInfo(std::size_t nodeId, LabelType label) : nodeId(nodeId), attributeIndex(0), attributeValue(0), leftNodeId(0), rightNodeId(0), label(std::move(label)), misclassProp(0), r(0), g(0) {}
94 
95  NodeInfo(NodeInfo const&) = default;
96  NodeInfo& operator=(NodeInfo const&) = default;
// NOTE(review): original line 97 is missing — from the initializer list that
// follows (moving from `n`), it is evidently the move-constructor signature,
// presumably `NodeInfo(NodeInfo&& n)`. Verify against the original header.
98  : nodeId{n.nodeId}, attributeIndex{n.attributeIndex},
99  attributeValue{n.attributeValue}, leftNodeId{n.leftNodeId},
100  rightNodeId{n.rightNodeId}, label(std::move(n.label)),
101  misclassProp{n.misclassProp}, r{n.r}, g{n.g}
102  {}
// NOTE(review): original line 103 is missing — the body below moves from `n`
// and returns *this, so it is evidently the move-assignment signature,
// presumably `NodeInfo& operator=(NodeInfo&& n)`. Verify against the original.
104  {
105  nodeId = n.nodeId;
106  attributeIndex = n.attributeIndex;
107  attributeValue = n.attributeValue;
108  leftNodeId = n.leftNodeId;
109  rightNodeId = n.rightNodeId;
110  label = std::move(n.label);
111  misclassProp = n.misclassProp;
112  r = n.r;
113  g = n.g;
114  return *this;
115  }
116  };
117 
118  /// Vector of structs that contains the splitting information and the labels.
119  /// The class label is a normalized histogram in the classification case.
120  /// In the regression case, the label is the regression value.
121  typedef std::vector<NodeInfo> TreeType;
122 
123  /// Constructor
// NOTE(review): original line 124 is missing — the default-constructor
// signature (presumably `CARTClassifier()`, delegated to at line 136 below),
// likely with its member initializers. Verify against the original header.
125  {}
126 
127  /// Constructor taking the tree as argument
128  explicit CARTClassifier(TreeType const& tree)
129  : m_tree(tree), m_inputDimension(0), m_OOBerror(0)
130  { }
131  explicit CARTClassifier(TreeType&& tree)
132  : m_tree(std::move(tree)), m_inputDimension(0), m_OOBerror(0)
133  { }
134 
135  /// Constructor taking the tree as argument and optimize it if requested
// When `optimize` is set, setTree() is used (which also optimizes the tree
// for index-based lookup); otherwise the tree is stored verbatim.
136  CARTClassifier(TreeType const& tree, bool optimize) : CARTClassifier()
137  {
138  if (optimize)
139  setTree(tree);
140  else
141  m_tree=tree;
142  }
143 
144  /// Constructor taking the tree as argument as well as maximum number of attributes
145  CARTClassifier(TreeType const& tree, std::size_t d)
// NOTE(review): the member-initializer order here (m_OOBerror first) does not
// match declaration order (m_tree, m_inputDimension, ..., m_OOBerror); members
// are still initialized in declaration order, so this is benign but triggers
// -Wreorder warnings in the original.
146  : m_OOBerror(0), m_tree{tree}, m_inputDimension{d}
147  {
// NOTE(review): original line 148 is missing from this body — presumably a call
// such as `optimizeTree(m_tree);`, mirroring the optimizing paths elsewhere.
// Verify against the original header.
149  }
150 
151  CARTClassifier(TreeType&& tree, std::size_t d) BOOST_NOEXCEPT_IF((std::is_nothrow_constructible<TreeType,TreeType>::value))
152  : m_tree(std::move(tree)), m_inputDimension{d}, m_OOBerror{0}
153  {
// NOTE(review): original line 154 is missing from this body — presumably the
// same post-construction step as line 148 above. Verify against the original.
155  }
156 
157  /// \brief From INameable: return the class name.
158  std::string name() const
159  { return "CARTClassifier"; }
160 
// Evaluation is stateless, so an EmptyState suffices.
161  boost::shared_ptr<State> createState() const{
162  return boost::shared_ptr<State>(new EmptyState());
163  }
164 
165  using base_type::eval;
166  /// \brief Evaluate the Tree on a batch of patterns
// The first pattern is evaluated up front so createBatch can size/shape the
// output batch from a concrete label value.
167  void eval(BatchInputType const& patterns, BatchOutputType & outputs) const{
168  std::size_t numPatterns = patterns.size1();
169  //evaluate the first pattern alone and create the batch output from that
170  LabelType const& firstResult = evalPattern(row(patterns,0));
171  outputs = Batch<LabelType>::createBatch(firstResult,numPatterns);
172  getBatchElement(outputs,0) = firstResult;
173 
174  //evaluate the rest
// NOTE(review): this loop starts at i = 0, so pattern 0 is evaluated twice
// (once above, once here) — redundant work, though the result is identical.
175  for(std::size_t i = 0; i != numPatterns; ++i){
176  getBatchElement(outputs,i) = evalPattern(row(patterns,i));
177  }
178  }
179 
// Stateful overload required by AbstractModel; state is unused (see createState).
180  void eval(BatchInputType const& patterns, BatchOutputType & outputs, State& state) const{
181  eval(patterns,outputs);
182  }
183  /// \brief Evaluate the Tree on a single pattern
184  void eval(RealVector const& pattern, LabelType& output){
185  output = evalPattern(pattern);
186  }
187 
188  /// Set the model tree.
189  void setTree(TreeType const& tree){
190  m_tree = tree;
// NOTE(review): original line 191 is missing from this body — presumably
// `optimizeTree(m_tree);`, which would explain why the `optimize` constructor
// at line 136 routes through setTree(). Verify against the original header.
192  }
193 
194  /// Get the model tree.
195  TreeType getTree() const {
196  return m_tree;
197  }
198 
199  /// \brief The model does not have any parameters.
200  std::size_t numberOfParameters() const{
201  return 0;
202  }
203 
204  /// \brief The model does not have any parameters.
205  RealVector parameterVector() const {
206  return RealVector();
207  }
208 
209  /// \brief The model does not have any parameters.
210  void setParameterVector(RealVector const& param) {
211  SHARK_ASSERT(param.size() == 0);
212  }
213 
214  /// from ISerializable, reads a model from an archive
// NOTE(review): only m_tree is (de)serialized — m_inputDimension,
// m_OOBerror and the feature importances are NOT restored by read().
215  void read(InArchive& archive){
216  archive >> m_tree;
217  }
218 
219  /// from ISerializable, writes a model to an archive
220  void write(OutArchive& archive) const {
221  archive << m_tree;
222  }
223 
224 
225  //Count how often attributes are used
// Returns, per input dimension, the number of internal (splitting) nodes
// that split on that dimension. Leaves (leftNodeId == 0) are skipped.
226  UIntVector countAttributes() const {
// NOTE(review): original line 227 is missing here — presumably a size/sanity
// check on m_inputDimension. Verify against the original header.
228  UIntVector r(m_inputDimension, 0);
229  typename TreeType::const_iterator it;
230  for(it = m_tree.begin(); it != m_tree.end(); ++it) {
231  //std::cout << "NodeId: " <<it->leftNodeId << std::endl;
232  if(it->leftNodeId != 0) { // not a label
233  r(it->attributeIndex)++;
234  }
235  }
236  return r;
237  }
238 
239  ///Return input dimension
240  std::size_t inputSize() const {
241  return m_inputDimension;
242  }
243 
244  //Set input dimension
245  void setInputDimension(std::size_t d) {
246  m_inputDimension = d;
247  }
248 
249  /// Compute oob error, given an oob dataset (Classification)
// NOTE(review): original lines 250 and 252 are missing — line 250 must be the
// function signature (presumably
// `void computeOOBerror(ClassificationDataset const& dataOOB){`, the overload
// called at line 291) and line 252 the declaration of `lossOOB` used at line
// 258 (presumably a zero-one loss). Verify against the original header.
251  // define loss
253 
254  // predict oob data
255  Data<RealVector> predOOB = (*this)(dataOOB.inputs());
256 
257  // count average number of oob misclassifications
258  m_OOBerror = lossOOB.eval(dataOOB.labels(), predOOB);
259  }
260 
261  /// Compute oob error, given an oob dataset (Regression)
262  void computeOOBerror(RegressionDataset const& dataOOB){
263  // define loss
// NOTE(review): original line 264 is missing — the declaration of `lossOOB`
// used at line 270 (presumably a squared loss, given the "mean squared error"
// comment). Verify against the original header.
265 
266  // predict oob data
267  Data<RealVector> predOOB = (*this)(dataOOB.inputs());
268 
269  // Compute mean squared error
270  m_OOBerror = lossOOB.eval(dataOOB.labels(), predOOB);
271  }
272 
273  /// Return OOB error
274  double OOBerror() const {
275  return m_OOBerror;
276  }
277 
278  /// Return feature importances
279  RealVector const& featureImportances() const {
280  return m_featureImportances;
281  }
282 
283  /// Compute feature importances, given an oob dataset (Classification)
// Permutation importance: for each feature, shuffle that column of the OOB
// data and record the drop in accuracy relative to the unpermuted OOB data.
284  void computeFeatureImportances(ClassificationDataset const& dataOOB, random::rng_type& rng){
// NOTE(review): original lines 285 and 288 are missing — presumably a resize of
// m_featureImportances to m_inputDimension (it is indexed at line 314) and the
// declaration of `lossOOB` used at line 311. Verify against the original.
286 
287  // define loss
289 
290  // compute oob error
291  computeOOBerror(dataOOB);
292 
293  // count average number of correct oob predictions
294  double accuracyOOB = 1. - m_OOBerror;
295 
296  // go through all dimensions, permute each dimension across all elements and train the tree on it
297  for(std::size_t i=0;i!=m_inputDimension;++i) {
298  // create permuted dataset by copying
299  ClassificationDataset pDataOOB(dataOOB);
300  pDataOOB.makeIndependent();
301 
302  // permute current dimension
303  RealVector v = getColumn(pDataOOB.inputs(), i);
304  std::shuffle(v.begin(), v.end(), rng);
305  setColumn(pDataOOB.inputs(), i, v);
306 
307  // evaluate the data set for which one feature dimension was permuted with this tree
308  Data<RealVector> pPredOOB = (*this)(pDataOOB.inputs());
309 
310  // count the number of correct predictions
311  double accuracyPermutedOOB = 1. - lossOOB.eval(pDataOOB.labels(),pPredOOB);
312 
313  // store importance
314  m_featureImportances[i] = std::fabs(accuracyOOB - accuracyPermutedOOB);
315  }
316  }
317 
318  /// Compute feature importances, given an oob dataset (Regression)
// Same permutation-importance scheme as above, measured in MSE increase.
319  void computeFeatureImportances(RegressionDataset const& dataOOB, random::rng_type& rng){
// NOTE(review): original lines 320 and 323 are missing — presumably the resize
// of m_featureImportances and the declaration of `lossOOB` used at line 346,
// mirroring the classification overload. Verify against the original header.
321 
322  // define loss
324 
325  // compute oob error
326  computeOOBerror(dataOOB);
327 
328  // mean squared error for oob sample
329  double mseOOB = m_OOBerror;
330 
331  // go through all dimensions, permute each dimension across all elements and train the tree on it
332  for(std::size_t i=0;i!=m_inputDimension;++i) {
333  // create permuted dataset by copying
334  RegressionDataset pDataOOB(dataOOB);
335  pDataOOB.makeIndependent();
336 
337  // permute current dimension
338  RealVector v = getColumn(pDataOOB.inputs(), i);
339  std::shuffle(v.begin(), v.end(), rng);
340  setColumn(pDataOOB.inputs(), i, v);
341 
342  // evaluate the data set for which one feature dimension was permuted with this tree
343  Data<RealVector> pPredOOB = (*this)(pDataOOB.inputs());
344 
345  // mean squared error of permuted oob sample
346  double msePermutedOOB = lossOOB.eval(pDataOOB.labels(),pPredOOB);
347 
348  // store importance
349  m_featureImportances[i] = std::fabs(msePermutedOOB - mseOOB);
350  }
351  }
352 
353 protected:
354  /// tree of the model
355  TreeType m_tree;
356 
357  /// \brief Finds the index of the node with a certain nodeID in an unoptimized tree.
// NOTE(review): linear scan with no bounds check — loops past the end (UB) if
// the nodeId is absent; callers must pass ids that exist in m_tree.
358  std::size_t findNode(std::size_t nodeId) const{
359  std::size_t index = 0;
360  for(; nodeId != m_tree[index].nodeId; ++index);
361  return index;
362  }
363 
364  /// Optimize a tree, so constant lookup can be used.
365  /// The optimization is done by changing the index of the children
366  /// to use indices instead of node ID.
367  /// Furthermore, the node IDs are converted to index numbers.
// NOTE(review): findNode searches m_tree, so this rewrite is only correct when
// `tree` is m_tree itself (as in setTree) or shares its node ids.
368  void optimizeTree(TreeType & tree) const{
369  for(std::size_t i = 0; i < tree.size(); i++){
370  tree[i].leftNodeId = findNode(tree[i].leftNodeId);
371  tree[i].rightNodeId = findNode(tree[i].rightNodeId);
372  }
373  for(std::size_t i = 0; i < tree.size(); i++){
374  tree[i].nodeId = i;
375  }
376  }
377 
378  /// Evaluate the CART tree on a single sample
// Descends from the root (index 0) taking the left child when the pattern's
// split-attribute value is <= the threshold, until a leaf (leftNodeId == 0)
// is reached; returns that leaf's label. Requires an optimized (index-based)
// tree and assumes the tree is non-empty.
379  template<class Vector>
380  LabelType const& evalPattern(Vector const& pattern) const{
381  std::size_t nodeId = 0;
382  while(m_tree[nodeId].leftNodeId != 0){
383  if(pattern[m_tree[nodeId].attributeIndex]<= m_tree[nodeId].attributeValue){
384  //Branch on left node
385  nodeId = m_tree[nodeId].leftNodeId;
386  }else{
387  //Branch on right node
388  nodeId = m_tree[nodeId].rightNodeId;
389  }
390  }
391  return m_tree[nodeId].label;
392  }
393 
394 
395  ///Number of attributes (set by trainer)
396  std::size_t m_inputDimension;
397 
398  // feature importances
// NOTE(review): original line 399 is missing — the declaration of
// `m_featureImportances` (presumably RealVector, given featureImportances()
// returns RealVector const&). Verify against the original header.
400 
401  // oob error
402  double m_OOBerror;
403 };
404 
405 
406 }
407 #endif