SvmLogisticInterpretation.h
/*!
 *
 *
 * \brief Maximum-likelihood model selection for binary support vector machines.
 *
 *
 *
 * \author M. Tuma, T. Glasmachers
 * \date 2009-2012
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <http://shark-ml.org/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
#ifndef SHARK_ML_SVMLOGISTICINTERPRETATION_H
#define SHARK_ML_SVMLOGISTICINTERPRETATION_H

#include <shark/ObjectiveFunctions/AbstractObjectiveFunction.h>
#include <shark/ObjectiveFunctions/ErrorFunction.h>
#include <shark/ObjectiveFunctions/Loss/CrossEntropy.h>
#include <shark/Algorithms/Trainers/CSvmTrainer.h>
#include <shark/Algorithms/GradientDescent/BFGS.h>
#include <shark/Models/Kernels/CSvmDerivative.h>
#include <shark/Data/CVDatasetTools.h>

namespace shark {

///
/// \brief Maximum-likelihood model selection score for binary support vector machines
///
/// \par
/// This class implements the maximum-likelihood based SVM model selection
/// procedure presented in the article "T. Glasmachers and C. Igel. Maximum
/// Likelihood Model Selection for 1-Norm Soft Margin SVMs with Multiple
/// Parameters. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2010."
/// At this point, only binary C-SVMs are supported.
/// \par
/// This class implements an AbstractObjectiveFunction. In detail, it provides
/// a differentiable measure of how well a C-SVM with given hyperparameters fulfills
/// the maximum-likelihood score presented in the paper. This error measure can then
/// be optimized externally via gradient-based optimizers. In other words, this
/// class provides a score, not an optimization method or a training algorithm. The
/// C-SVM parameters have to be optimized with regard to this measure.
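///
/// \par
/// A minimal usage sketch (illustrative names; assumes a binary ClassificationDataset
/// \c data and a differentiable kernel, here a Gaussian RBF kernel):
/// \code
/// GaussianRbfKernel<> kernel(0.5); //initial kernel bandwidth
/// CVFolds<ClassificationDataset> folds = createCVSameSizeBalanced(data, 5); //5-fold CV
/// SvmLogisticInterpretation<> mlms(folds, &kernel); //the score implemented by this class
/// IRpropPlus optimizer;
/// RealVector start(mlms.numberOfVariables(), 0.0); //kernel parameter(s), then log(C)
/// optimizer.init(mlms, start);
/// for (std::size_t t = 0; t != 50; ++t) optimizer.step(mlms); //gradient-based model selection
/// double score = optimizer.solution().value; //the final maximum-likelihood score
/// \endcode
///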
template<class InputType = RealVector>
class SvmLogisticInterpretation : public AbstractObjectiveFunction< RealVector, double > {
public:
    typedef CVFolds< LabeledData<InputType, unsigned int> > FoldsType;
    typedef AbstractKernelFunction<InputType> KernelType;
protected:
    FoldsType m_folds;          ///< the underlying partitioned dataset
    KernelType *mep_kernel;     ///< the kernel with which to run the SVM
    std::size_t m_nhp;          ///< for convenience, the number of hyperparameters
    std::size_t m_nkp;          ///< for convenience, the number of kernel parameters
    std::size_t m_numFolds;     ///< the number of folds to be used in cross-validation
    std::size_t m_numSamples;   ///< overall number of samples in the dataset
    std::size_t m_inputDims;    ///< input dimensionality
    bool m_svmCIsUnconstrained; ///< whether the SVM regularization parameter C is passed in unconstrained (log-encoded) form, so that the derivative has to compensate for this
    QpStoppingCondition *mep_svmStoppingCondition; ///< the stopping criterion that is to be passed to the SVM trainer
public:

    //! constructor.
    //! \param folds an already partitioned dataset (i.e., a CVFolds object)
    //! \param kernel pointer to the kernel to be used within the SVMs.
    //! \param unconstrained whether or not the C-parameter of/for the C-SVM is passed for unconstrained optimization mode.
    //! \param stop_cond the stopping conditions which are to be passed to the SVM trainer.
    SvmLogisticInterpretation(
        FoldsType const &folds, KernelType *kernel,
        bool unconstrained = true, QpStoppingCondition *stop_cond = NULL
    )
    : mep_kernel(kernel)
    , m_nhp(kernel->parameterVector().size()+1)
    , m_nkp(kernel->parameterVector().size())
    , m_numFolds(folds.size()) //gets number of folds!
    , m_numSamples(folds.dataset().numberOfElements())
    , m_inputDims(inputDimension(folds.dataset()))
    , m_svmCIsUnconstrained(unconstrained)
    , mep_svmStoppingCondition(stop_cond)
    {
        SHARK_RUNTIME_CHECK(kernel != NULL, "[SvmLogisticInterpretation::SvmLogisticInterpretation] kernel is not allowed to be NULL"); //mtq: necessary despite indirect check via call in initialization list?
        SHARK_RUNTIME_CHECK(m_numFolds > 1, "[SvmLogisticInterpretation::SvmLogisticInterpretation] please provide a meaningful number of folds for cross validation");
        if (!m_svmCIsUnconstrained) //mtq: important: we additionally need to deal with kernel feasibility indicators! important!
            m_features |= IS_CONSTRAINED_FEATURE;
        m_features |= HAS_VALUE;
        if (mep_kernel->hasFirstParameterDerivative())
            m_features |= HAS_FIRST_DERIVATIVE;
        m_folds = folds;
    }

    /// \brief From INameable: return the class name.
    std::string name() const
    { return "SvmLogisticInterpretation"; }

    //! checks whether the search point provided is feasible
    //! \param input the point to test for feasibility
    bool isFeasible(const SearchPointType &input) const {
        SHARK_ASSERT(input.size() == m_nhp);
        if (input(m_nhp-1) <= 0.0 && !m_svmCIsUnconstrained) {
            //in constrained mode, the regularization parameter C (the last entry) must be positive
            return false;
        }
        return true;
    }

    std::size_t numberOfVariables()const{
        return m_nhp;
    }

    //! train a number of SVMs in a cross-validation setting using the hyperparameters passed to this method.
    //! The output scores from all validation sets are then concatenated. Together with the true labels, these
    //! scores can then be used to fit a sigmoid such that it becomes as good a model as possible for the
    //! class membership probabilities given the SVM output scores. This method returns the negative likelihood
    //! of the best fitting sigmoid, given a set of SVM hyperparameters.
    //! \param parameters the SVM hyperparameters to use for all C-SVMs
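    //!
    //! In detail, a sigmoid \f$ \sigma(s) = 1/(1+e^{-(a s + b)}) \f$ with slope \f$ a \f$ and
    //! offset \f$ b \f$ is fitted to the stacked validation scores \f$ s \f$ by minimizing the
    //! cross-entropy, i.e., the negative log-likelihood of the true labels, in the spirit of
    //! Platt scaling. The expected layout of \c parameters is: all kernel parameters first,
    //! followed by the regularization parameter C (log-encoded in unconstrained mode).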
    double eval(SearchPointType const &parameters) const {
        SHARK_RUNTIME_CHECK(m_nhp == parameters.size(), "[SvmLogisticInterpretation::eval] wrong number of parameters");
        // initialize, copy parameters
        double C_reg = (m_svmCIsUnconstrained ? std::exp(parameters(m_nkp)) : parameters(m_nkp)); //set up regularization parameter
        mep_kernel->setParameterVector(subrange(parameters, 0, m_nkp)); //set up kernel parameters
        // stores the stacked CV predictions for every fold
        ClassificationDataset validation_dataset;
        // for each fold, train an SVM and get predictions on the validation data
        for (std::size_t i=0; i<m_numFolds; i++) {
            // init SVM
            KernelClassifier<InputType> svm; //the SVM
            CSvmTrainer<InputType, double> csvm_trainer(mep_kernel, C_reg, true, m_svmCIsUnconstrained); //the trainer
            csvm_trainer.sparsify() = false;
            if (mep_svmStoppingCondition != NULL) {
                csvm_trainer.stoppingCondition() = *mep_svmStoppingCondition;
            }

            // train SVM on current training fold
            csvm_trainer.train(svm, m_folds.training(i));

            //append validation predictions
            validation_dataset.append(transformInputs(m_folds.validation(i), svm.decisionFunction()));
        }

        // fit a logistic regression to the predictions
        LinearModel<> logistic_model = fitLogistic(validation_dataset);

        //to evaluate, we use the cross-entropy loss on the fitted model
        CrossEntropy logistic_loss;
        return logistic_loss(validation_dataset.labels(), logistic_model(validation_dataset.inputs()));
    }

    //! the derivative of the eval() function above w.r.t. the parameters.
    //! \param parameters the SVM hyperparameters to use for all C-SVMs
    //! \param derivative will store the computed derivative w.r.t. the current hyperparameters
    // mtq: should this also follow the first-call-error()-then-call-deriv() paradigm?
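    //
    // By the chain rule, with sigmoid slope a and per-sample SVM scores f_i(theta), the
    // derivative of the mean cross-entropy E w.r.t. the hyperparameters theta is
    //     dE/dtheta = (a/N) * sum_i (dL/do_i) * (df_i/dtheta),   where o_i = a*f_i + b.
    // The loop below accumulates this from the per-sample loss gradients and the rows of
    // all_validation_predict_derivs, then scales by a and divides by N = m_numSamples.
    // (Derivative terms through the refitted (a,b) vanish, since (a,b) minimize the loss.)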
    double evalDerivative(SearchPointType const &parameters, FirstOrderDerivative &derivative) const {
        SHARK_RUNTIME_CHECK(m_nhp == parameters.size(), "[SvmLogisticInterpretation::evalDerivative] wrong number of parameters");
        // initialize, copy parameters
        double C_reg = (m_svmCIsUnconstrained ? std::exp(parameters(m_nkp)) : parameters(m_nkp)); //set up regularization parameter
        mep_kernel->setParameterVector(subrange(parameters, 0, m_nkp)); //set up kernel parameters
        // these two will be filled in the order corresponding to all CV validation partitions
        // stacked behind one another, and then used to create datasets with
        std::vector< unsigned int > tmp_helper_labels(m_numSamples);
        std::vector< RealVector > tmp_helper_preds(m_numSamples);

        unsigned int next_label = 0; //helper index counter to monitor the next position to be filled in the above vectors
        // init variables especially for derivative
        RealMatrix all_validation_predict_derivs(m_numSamples, m_nhp); //will hold derivatives of all output scores w.r.t. all hyperparameters
        RealVector der; //temporary helper for derivative calls

        // for each fold, train an SVM and get predictions on the validation data
        for (std::size_t i=0; i<m_numFolds; i++) {
            // get current train/validation partitions as well as corresponding labels
            ClassificationDataset cur_train_data = m_folds.training(i);
            ClassificationDataset cur_valid_data = m_folds.validation(i);
            std::size_t cur_vsize = cur_valid_data.numberOfElements();
            Data< unsigned int > cur_vlabels = cur_valid_data.labels(); //validation labels of this fold
            Data< RealVector > cur_vinputs = cur_valid_data.inputs(); //validation inputs of this fold
            Data< RealVector > cur_vscores; //will hold SVM output scores for current validation partition
            // init SVM
            KernelClassifier<InputType> svm; //the SVM
            CSvmTrainer<InputType, double> csvm_trainer(mep_kernel, C_reg, true, m_svmCIsUnconstrained); //the trainer
            csvm_trainer.sparsify() = false;
            csvm_trainer.setComputeBinaryDerivative(true);
            if (mep_svmStoppingCondition != NULL) {
                csvm_trainer.stoppingCondition() = *mep_svmStoppingCondition;
            }
            // train SVM on current fold
            csvm_trainer.train(svm, cur_train_data);
            CSvmDerivative<InputType> svm_deriv(&svm, &csvm_trainer);
            cur_vscores = svm.decisionFunction()(cur_vinputs); //results in a dataset of RealVector as output
            // copy the scores and corresponding labels to the dataset-wide storage
            for (std::size_t j=0; j<cur_vsize; j++) {
                // copy label and prediction score
                tmp_helper_labels[next_label] = cur_vlabels.element(j);
                tmp_helper_preds[next_label] = cur_vscores.element(j);
                // get and store the derivative of the score w.r.t. the hyperparameters
                svm_deriv.modelCSvmParameterDerivative(cur_vinputs.element(j), der);
                noalias(row(all_validation_predict_derivs, next_label)) = der; //fast assignment of the derivative to the correct matrix row
                ++next_label;
            }
        }

        // now we have it all: the predictions across the validation folds, plus the
        // corresponding correct labels. so we go ahead and fit a logistic regression
        ClassificationDataset validation_dataset = createLabeledDataFromRange(tmp_helper_preds, tmp_helper_labels);
        LinearModel<> logistic_model = fitLogistic(validation_dataset);

        // to evaluate, we use the cross-entropy loss on the fitted model and compute
        // the derivative w.r.t. the SVM hyperparameters.
        derivative.resize(m_nhp);
        derivative.clear();
        double error = 0;
        std::size_t start = 0;
        for(auto const& batch: validation_dataset.batches()){
            std::size_t end = start + batch.size();
            CrossEntropy logistic_loss;
            RealMatrix lossGradient;
            error += logistic_loss.evalDerivative(batch.label, logistic_model(batch.input), lossGradient);
            noalias(derivative) += column(lossGradient,0) % rows(all_validation_predict_derivs, start, end);
            start = end;
        }
        derivative *= logistic_model.parameterVector()(0); //chain rule: scale by the sigmoid slope a
        derivative /= m_numSamples;
        return error / m_numSamples;
    }
private:
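    /// \brief Fit a one-dimensional logistic regression (a Platt-style sigmoid with slope
    /// and offset) to the stacked SVM validation scores by minimizing its cross-entropy
    /// error with BFGS.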
    LinearModel<> fitLogistic(ClassificationDataset const& data)const{
        LinearModel<> logistic_model;
        logistic_model.setStructure(1, 1, true); //1 input, 1 output, bias = 2 parameters
        CrossEntropy logistic_loss;
        ErrorFunction error(data, &logistic_model, &logistic_loss);
        BFGS optimizer;
        optimizer.init(error);
        //this converges after very few iterations (typically 20 function evaluations)
        while(norm_2(optimizer.derivative()) > 1.e-8){
            double lastValue = optimizer.solution().value;
            optimizer.step(error);
            if(lastValue == optimizer.solution().value) break; //we are done due to numerical precision
        }
        logistic_model.setParameterVector(optimizer.solution().point);
        return logistic_model;
    }
};


} // namespace shark
#endif