ConvexCombination.h
/*!
 *
 *
 * \brief Implements a model computing a convex combination of its inputs.
 *
 *
 *
 * \author T. Glasmachers, O. Krause
 * \date 2010-2011
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <http://shark-ml.org/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
#ifndef SHARK_MODELS_ConvexCombination_H
#define SHARK_MODELS_ConvexCombination_H

#include <shark/Models/AbstractModel.h>

namespace shark {

///
/// \brief Models a convex combination of inputs
///
/// For a given input vector x, the convex combination returns \f$ f_i(x) = \sum_j w_{ij} x_j \f$,
/// where \f$ w_{ij} > 0 \f$ and \f$ \sum_j w_{ij} = 1 \f$, that is, the outputs of
/// the model are a convex combination of the inputs.
///
/// To ensure that the constraints are fulfilled, the model uses a different
/// set of weights \f$ q_{ij} \f$ with \f$ w_{ij} = \exp(q_{ij}) / \sum_k \exp(q_{ik}) \f$. As usual, this
/// encoding is only used for the derivatives and the parameter vectors, not
/// when the weights are set explicitly. In the latter case, the user must provide
/// a set of suitable \f$ w_{ij} \f$.
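///
/// A minimal usage sketch (the weights and inputs below are illustrative
/// values, not an example from the library's documentation):
/// \code
/// ConvexCombination model(2, 1);    // two inputs, one output, weights zeroed
/// model.weights()(0,0) = 0.25;      // each row of the weight matrix
/// model.weights()(0,1) = 0.75;      // must sum to 1
/// RealMatrix inputs(1,2), outputs;  // a batch holding a single pattern
/// inputs(0,0) = 1.0; inputs(0,1) = 3.0;
/// model.eval(inputs, outputs);      // outputs(0,0) == 0.25*1 + 0.75*3 == 2.5
/// \endcode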
class ConvexCombination : public AbstractModel<RealVector,RealVector>
{
private:
	RealMatrix m_w; ///< the convex combination weights; it holds sum(row(m_w,i)) = 1
public:

	/// Default Constructor; use setStructure later
	ConvexCombination(){
		m_features |= HAS_FIRST_PARAMETER_DERIVATIVE;
		m_features |= HAS_FIRST_INPUT_DERIVATIVE;
	}

	/// Constructor creating a model with given dimensionalities
	ConvexCombination(std::size_t inputs, std::size_t outputs = 1)
	: m_w(outputs,inputs,0.0){
		m_features |= HAS_FIRST_PARAMETER_DERIVATIVE;
		m_features |= HAS_FIRST_INPUT_DERIVATIVE;
	}

	/// Construction from a weight matrix
	ConvexCombination(RealMatrix const& matrix):m_w(matrix){
		m_features |= HAS_FIRST_PARAMETER_DERIVATIVE;
		m_features |= HAS_FIRST_INPUT_DERIVATIVE;
	}

	/// \brief From INameable: return the class name.
	std::string name() const
	{ return "ConvexCombination"; }

	///swap
	friend void swap(ConvexCombination& model1, ConvexCombination& model2){
		swap(model1.m_w, model2.m_w);
	}

	///operator =
	ConvexCombination& operator=(ConvexCombination const& model){
		ConvexCombination tempModel(model);
		swap(*this, tempModel);
		return *this;
	}

	/// obtain the input dimension
	std::size_t inputSize() const{
		return m_w.size2();
	}

	/// obtain the output dimension
	std::size_t outputSize() const{
		return m_w.size1();
	}

	/// obtain the parameter vector
	RealVector parameterVector() const{
		return to_vector(log(m_w));
	}

	/// overwrite the parameter vector
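	/// The parameter vector holds the unconstrained weights \f$ q_{ij} \f$;
	/// each row of \f$ \exp(q) \f$ is normalized to sum to 1, so the stored
	/// weights always form a valid convex combination.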
	void setParameterVector(RealVector const& newParameters)
	{
		noalias(m_w) = exp(to_matrix(newParameters,m_w.size1(),m_w.size2()));
		for(std::size_t i = 0; i != outputSize(); ++i){
			row(m_w,i) /= sum(row(m_w,i));
		}
	}

	/// return the number of parameters
	std::size_t numberOfParameters() const{
		return m_w.size1()*m_w.size2();
	}

	/// overwrite structure and parameters
	void setStructure(std::size_t inputs, std::size_t outputs = 1){
		ConvexCombination model(inputs,outputs);
		swap(*this,model);
	}

	/// Return the weight matrix.
	RealMatrix const& weights() const{
		return m_w;
	}

	/// Return the weight matrix; when set directly, every row must sum to 1.
	RealMatrix& weights(){
		return m_w;
	}

	/// The model is stateless, so an empty state object is returned.
	boost::shared_ptr<State> createState()const{
		return boost::shared_ptr<State>(new EmptyState());
	}
	/// Evaluate the model: output = w * input
	void eval(BatchInputType const& inputs, BatchOutputType& outputs)const{
		outputs.resize(inputs.size1(),m_w.size1());
		noalias(outputs) = prod(inputs,trans(m_w));
	}
	/// Evaluate the model: output = w * input
	void eval(BatchInputType const& inputs, BatchOutputType& outputs, State& state)const{
		eval(inputs,outputs);
	}

	///\brief Calculates the first derivative w.r.t. the parameters and sums it up over all patterns of the last computed batch
	void weightedParameterDerivative(
		BatchInputType const& patterns, RealMatrix const& coefficients, State const& state, RealVector& gradient
	)const{
		SIZE_CHECK(coefficients.size2() == outputSize());
		SIZE_CHECK(coefficients.size1() == patterns.size1());

		gradient.resize(numberOfParameters());
		blas::dense_matrix_adaptor<double> weightGradient = blas::to_matrix(gradient,outputSize(),inputSize());

		//the derivative is
		//sum_i sum_j c_ij sum_k x_ik grad_q w_jk = sum_k sum_j grad_q w_jk (sum_i c_ij x_ik)
		//with d_jk = sum_i c_ij x_ik, i.e. d = C^T X
		RealMatrix d = prod(trans(coefficients), patterns);

		//use the same derivative as in the softmax model
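		//softmax derivative: grad_{q_jl} w_jk = w_jk*(delta_kl - w_jl), hence
		//row j of the gradient equals w_j*(d_j - inner_prod(d_j,w_j)) elementwise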
		for(std::size_t i = 0; i != outputSize(); ++i){
			double mass = inner_prod(row(d,i),row(m_w,i));
			noalias(row(weightGradient,i)) = element_prod(
				row(d,i) - mass,
				row(m_w,i)
			);
		}
	}
	///\brief Calculates the first derivative w.r.t. the inputs and sums it up over all patterns of the last computed batch
	void weightedInputDerivative(
		BatchInputType const & patterns,
		BatchOutputType const & coefficients,
		State const& state,
		BatchInputType& derivative
	)const{
		SIZE_CHECK(coefficients.size2() == outputSize());
		SIZE_CHECK(coefficients.size1() == patterns.size1());

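		//f(x) = W x, so the derivative of output j w.r.t. the inputs is row j
		//of W; weighting by the coefficients and summing over outputs gives C W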
		derivative.resize(patterns.size1(),inputSize());
		noalias(derivative) = prod(coefficients,m_w);
	}

	/// From ISerializable
	void read(InArchive& archive){
		archive >> m_w;
	}
	/// From ISerializable
	void write(OutArchive& archive) const{
		archive << m_w;
	}
};


}
#endif