ROC.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief ROC
6  *
7  *
8  *
9  * \author O.Krause
10  * \date 2010-2011
11  *
12  *
13  * \par Copyright 1995-2017 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://shark-ml.org/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 #ifndef SHARK_OBJECTIVEFUNCTIONS_ROC_H
35 #define SHARK_OBJECTIVEFUNCTIONS_ROC_H
36 
37 #include <shark/Core/DLLSupport.h>
39 #include <shark/Data/Dataset.h>
40 #include <vector>
41 #include <algorithm>
42 
43 namespace shark {
44 
45 //!
46 //! \brief ROC-Curve - false negatives over false positives
47 //!
48 //! \par
49 //! This class provides the ROC curve of a classifier.
50 //! All time consuming computations are done in the constructor,
51 //! such that afterwards fast access to specific values of the
52 //! curve and the equal error rate is possible.
53 //!
54 //! \par
55 //! The ROC class assumes a one dimensional target array and a
56 //! model producing one dimensional output data. The targets must
57 //! be the labels 0 and 1 of a binary classification task. The
58 //! model output is assumed not to be 0 and 1, but real valued
59 //! instead. Classification is done by thresholding, where
60 //! different false positive and false negative rates correspond
61 //! to different thresholds. The ROC curve shows the trade off
62 //! between the two error types.
63 //!
64 class ROC
65 {
66 public:
67  //! Constructor
68  //!
69  //! \param model model to use for prediction
70  //! \param set data set with inputs and corresponding binary outputs (0 or 1)
71  template<class InputType>
73 
74  //calculat the number of classes
75  std::vector<std::size_t> classes = classSizes(set);
76  SIZE_CHECK(classes.size() == 2); //only binary problems allowed!
77 
78  std::size_t positive = classes[0];
79  std::size_t negative = classes[1];
80  m_scorePositive.resize(positive);
81  m_scoreNegative.resize(negative);
82 
83  // compute scores
84  std::size_t posPositive = 0;
85  std::size_t posNegative = 0;
86 
87  //calculate the model responses batchwise for the whole set
88  for(std::size_t i = 0; i != set.size(); ++i){
89  RealMatrix output = model(set.batch(i).input);
90  SIZE_CHECK(output.size2() == 1);
91  for(std::size_t j = 0; j != output.size1(); ++j){
92  double value = output(j,0);
93  if (set.batch(i)(j) == 1)
94  {
95  m_scorePositive[posPositive] = value;
96  posPositive++;
97  }
98  else
99  {
100  m_scoreNegative[posNegative] = value;
101  posNegative++;
102  }
103  }
104  }
105  // sort positives and negatives by score
106  std::sort(m_scorePositive.begin(), m_scorePositive.end());
107  std::sort(m_scoreNegative.begin(), m_scoreNegative.end());
108  }
109 
110  //! Compute the threshold for given false acceptance rate,
111  //! that is, for a given false positive rate.
112  //! This threshold, used for classification with the underlying
113  //! model, results in the given false acceptance rate.
114  SHARK_EXPORT_SYMBOL double threshold(double falseAcceptanceRate)const;
115 
116  //! Value of the ROC curve for given false acceptance rate,
117  //! that is, for a given false positive rate.
118  SHARK_EXPORT_SYMBOL double value(double falseAcceptanceRate)const;
119 
120  //! Computes the equal error rate of the classifier
121  SHARK_EXPORT_SYMBOL double equalErrorRate()const;
122 
123 protected:
124  //! scores of the positive examples
125  std::vector<double> m_scorePositive;
126 
127  //! scores of the negative examples
128  std::vector<double> m_scoreNegative;
129 };
130 
131 }
132 #endif