AbstractLoss.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  * \brief super class of all loss functions
4  *
5  * \author T. Glasmachers
6  * \date 2010-2011
7  *
8  * \par Copyright (c) 2010-2011:
9  * Institut f&uuml;r Neuroinformatik<BR>
10  * Ruhr-Universit&auml;t Bochum<BR>
11  * D-44780 Bochum, Germany<BR>
12  * Phone: +49-234-32-25558<BR>
13  * Fax: +49-234-32-14209<BR>
14  * eMail: Shark-admin@neuroinformatik.ruhr-uni-bochum.de<BR>
15  * www: http://www.neuroinformatik.ruhr-uni-bochum.de<BR>
16  *
17  * <BR><HR>
18  * This file is part of Shark. This library is free software;
19  * you can redistribute it and/or modify it under the terms of the
20  * GNU General Public License as published by the Free Software
21  * Foundation; either version 3, or (at your option) any later version.
22  *
23  * This library is distributed in the hope that it will be useful,
24  * but WITHOUT ANY WARRANTY; without even the implied warranty of
25  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
26  * GNU General Public License for more details.
27  *
28  * You should have received a copy of the GNU General Public License
29  * along with this library; if not, see <http://www.gnu.org/licenses/>.
30  *
31  */
32 
33 #ifndef SHARK_OBJECTIVEFUNCTIONS_LOSS_ABSTRACTLOSS_H
34 #define SHARK_OBJECTIVEFUNCTIONS_LOSS_ABSTRACTLOSS_H
35 
#include <shark/ObjectiveFunctions/AbstractCost.h>
#include <shark/LinAlg/Base.h>
38 namespace shark {
39 
40 
41 /// \brief Loss function interface
42 ///
43 /// \par
44 /// In statistics and machine learning, a loss function encodes
45 /// the severity of getting a label wrong. This is am important
46 /// special case of a cost function (see AbstractCost), where
47 /// the cost is computed as the average loss over a set, also
48 /// known as (empirical) risk.
49 ///
50 /// \par
51 /// It is generally agreed that loss values are non-negative,
52 /// and that the loss of correct prediction is zero. This rule
53 /// is not formally checked, but instead left to the various
54 /// sub-classes.
55 ///
56 template<class LabelT, class OutputT = LabelT>
57 class AbstractLoss : public AbstractCost<LabelT, OutputT>
58 {
59 public:
61  typedef OutputT OutputType;
62  typedef LabelT LabelType;
64 
67 
70  }
71 
72  /// \brief evaluate the loss for a batch of targets and a prediction
73  ///
74  /// \param target target values
75  /// \param prediction predictions, typically made by a model
76  virtual double eval( BatchLabelType const& target, BatchOutputType const& prediction) const = 0;
77 
78  /// \brief evaluate the loss for a target and a prediction
79  ///
80  /// \param target target value
81  /// \param prediction prediction, typically made by a model
82  virtual double eval( LabelType const& target, OutputType const& prediction)const{
83  BatchLabelType labelBatch = Batch<LabelType>::createBatch(target,1);
84  get(labelBatch,0)=target;
85  BatchOutputType predictionBatch = Batch<OutputType>::createBatch(prediction,1);
86  get(predictionBatch,0)=prediction;
87  return eval(labelBatch,predictionBatch);
88  }
89 
90  /// \brief evaluate the loss and the derivative w.r.t. the prediction
91  ///
92  /// \par
93  /// The default implementations throws an exception.
94  /// If you overwrite this method, don't forget to set
95  /// the flag HAS_FIRST_DERIVATIVE.
96  /// \param target target value
97  /// \param prediction prediction, typically made by a model
98  /// \param gradient the gradient of the loss function with respect to the prediction
99  virtual double evalDerivative(BatchLabelType const& target, BatchOutputType const& prediction, BatchOutputType& gradient) const
100  {
102  return 0.0; // dead code, prevent warning
103  }
104 
105  //~ /// \brief evaluate the loss and fist and second derivative w.r.t. the prediction
106  //~ ///
107  //~ /// \par
108  //~ /// The default implementations throws an exception.
109  //~ /// If you overwrite this method, don't forget to set
110  //~ /// the flag HAS_FIRST_DERIVATIVE.
111  //~ /// \param target target value
112  //~ /// \param prediction prediction, typically made by a model
113  //~ /// \param gradient the gradient of the loss function with respect to the prediction
114  //~ /// \param hessian the hessian matrix of the loss function with respect to the prediction
115  //~ virtual double evalDerivative(
116  //~ LabelType const& target,
117  //~ OutputType const& prediction,
118  //~ OutputType& gradient,
119  //~ MatrixType& hessian) const
120  //~ {
121  //~ SHARK_FEATURE_EXCEPTION_DERIVED(HAS_SECOND_DERIVATIVE);
122  //~ return 0.0; // dead code, prevent warning
123  //~ }
124 
125  /// from AbstractCost
126  ///
127  /// \param targets target values
128  /// \param predictions predictions, typically made by a model
129  double eval(Data<LabelType> const& targets, Data<OutputType> const& predictions) const{
130  SIZE_CHECK(predictions.numberOfElements() == targets.numberOfElements());
131  SIZE_CHECK(predictions.numberOfBatches() == targets.numberOfBatches());
132  int numBatches = (int) targets.numberOfBatches();
133  double error = 0;
134  SHARK_PARALLEL_FOR(int i = 0; i < numBatches; ++i){
135  double batchError= eval(targets.batch(i),predictions.batch(i));
137  error+=batchError;
138  }
139  }
140  return error / targets.numberOfElements();
141  }
142 
143  /// from AbstractCost
144  ///
145  /// \param targets target values
146  /// \param predictions predictions, typically made by a model
147  /// \param gradient the gradient of the cost function with respect to the predictions
149  Data<LabelType> const& targets,
150  Data<OutputType> const& predictions,
151  Data<OutputType>& gradient
152  ) const{
154  SIZE_CHECK(predictions.numberOfElements() == targets.numberOfElements());
155 
156  int numBatches = (int) targets.numberOfBatches();
157  std::size_t elements = targets.numberOfElements();
158  gradient = Data<OutputType>(numBatches);
159 
160  double error = 0;
161  SHARK_PARALLEL_FOR(int i = 0; i < numBatches; ++i){
162  double batchError= evalDerivative(targets.batch(i),predictions.batch(i),gradient.batch(i));
163  gradient.batch(i) /= elements;//we return the mean of the loss
165  error+=batchError;
166  }
167  }
168  return error/elements;
169  }
170 
171  /// \brief evaluate the loss for a target and a prediction
172  ///
173  /// \par
174  /// convenience operator
175  ///
176  /// \param target target value
177  /// \param prediction prediction, typically made by a model
178  double operator () (LabelType const& target, OutputType const& prediction) const
179  { return eval(target, prediction); }
180 
181  double operator () (BatchLabelType const& target, BatchOutputType const& prediction) const
182  { return eval(target, prediction); }
183 
184  using base_type::operator();
185 };
186 
187 
188 }
189 #endif