ContrastiveDivergence.h
/*!
 *
 *
 * \brief -
 *
 * \author -
 * \date -
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <http://shark-ml.org/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
#ifndef SHARK_UNSUPERVISED_RBM_GRADIENTAPPROXIMATIONS_CONTRASTIVEDIVERGENCE_H
#define SHARK_UNSUPERVISED_RBM_GRADIENTAPPROXIMATIONS_CONTRASTIVEDIVERGENCE_H

#include <shark/Data/Dataset.h>
#include <shark/ObjectiveFunctions/AbstractObjectiveFunction.h>
#include <algorithm>
#include <limits>
#include <vector>

namespace shark{

/// \brief Implements k-step Contrastive Divergence as described by Hinton et al. (2006).
///
/// k-step Contrastive Divergence (CD-k) approximates the log-likelihood gradient by
/// initializing a Gibbs chain with a training example and running it for k steps.
/// The sample obtained after k steps is then used to approximate the expectation over
/// the model distribution that appears in the gradient.
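///
/// A minimal usage sketch is given below. It assumes that \c rbm is an already
/// constructed and initialized RBM, that \c data is an UnlabeledData<RealVector>
/// training set, and that \c BinaryCD is the usual typedef of ContrastiveDivergence
/// for a binary RBM; the learning rate and iteration count are placeholders. In
/// practice the objective would typically be plugged into one of Shark's
/// gradient-based optimizers instead of the manual descent loop shown here.
/// \code
/// BinaryCD cd(&rbm);                            // CD objective for the given RBM
/// cd.setData(data);                             // training data
/// cd.setK(1);                                   // CD-1: one Gibbs step per gradient
/// cd.numBatches() = 1;                          // one random minibatch per evaluation
///
/// double learningRate = 0.1;
/// RealVector point = cd.proposeStartingPoint(); // current RBM parameters
/// RealVector gradient(cd.numberOfVariables());
/// for(std::size_t i = 0; i != 1000; ++i){
///     cd.evalDerivative(point, gradient);       // CD-k gradient approximation
///     noalias(point) -= learningRate*gradient;  // simple gradient descent step
/// }
/// \endcode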
template<class Operator>
class ContrastiveDivergence: public SingleObjectiveFunction{
public:
    typedef typename Operator::RBM RBM;

    /// \brief The constructor.
    ///
    /// @param rbm pointer to the RBM which shall be trained
    ContrastiveDivergence(RBM* rbm)
    : mpe_rbm(rbm), m_operator(rbm)
    , m_k(1), m_numBatches(0), m_regularizer(0){
        SHARK_ASSERT(rbm != NULL);

        m_features |= HAS_FIRST_DERIVATIVE;
        m_features |= CAN_PROPOSE_STARTING_POINT;
    };

    /// \brief From INameable: return the class name.
    std::string name() const
    { return "ContrastiveDivergence"; }

    /// \brief Sets the training batch.
    ///
    /// @param data the batch of training data
    void setData(UnlabeledData<RealVector> const& data){
        m_data = data;
    }

    /// \brief Sets the value of k, the number of steps of the Gibbs chain.
    ///
    /// The default set by the constructor is 1, the common CD-1 variant.
    /// @param k the number of steps
    void setK(unsigned int k){
        m_k = k;
    }

    /// \brief Returns the current parameter vector of the RBM as starting point.
    SearchPointType proposeStartingPoint() const{
        return mpe_rbm->parameterVector();
    }

    /// \brief Returns the number of variables of the RBM.
    ///
    /// @return the number of variables of the RBM
    std::size_t numberOfVariables() const{
        return mpe_rbm->numberOfParameters();
    }

    /// \brief Returns the number of batches of the dataset that are used in every iteration.
    ///
    /// If it is less than the total number of batches, the batches are chosen at random.
    /// If it is 0, all batches are used.
    std::size_t numBatches() const{
        return m_numBatches;
    }

    /// \brief Returns a reference to the number of batches of the dataset that are used in every iteration.
    ///
    /// If it is less than the total number of batches, the batches are chosen at random.
    /// If it is 0, all batches are used.
    std::size_t& numBatches(){
        return m_numBatches;
    }

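    /// \brief Sets a regularizer and its weight.
    ///
    /// The regularizer's derivative, scaled by factor, is added to the
    /// CD gradient in evalDerivative().
    /// @param factor the regularization strength
    /// @param regularizer pointer to the regularizing objective function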
    void setRegularizer(double factor, SingleObjectiveFunction* regularizer){
        m_regularizer = regularizer;
        m_regularizationStrength = factor;
    }

    /// \brief Gives the CD-k approximation of the log-likelihood gradient.
    ///
    /// @param parameter the current parameters of the RBM
    /// @param derivative on return, holds the CD-k approximation of the log-likelihood gradient
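    ///
    /// The computation follows the standard CD-k scheme: the intractable expectation
    /// over the model distribution in the log-likelihood gradient is replaced by an
    /// average over the samples \f$v^{(k)}\f$ obtained after k steps of Gibbs sampling
    /// started at the training examples \f$v^{(0)}\f$,
    /// \f[
    ///     -\frac{\partial \log p(v^{(0)})}{\partial \theta} \approx
    ///     \left\langle \frac{\partial E(v,h)}{\partial \theta} \right\rangle_{v^{(0)}}
    ///     - \left\langle \frac{\partial E(v,h)}{\partial \theta} \right\rangle_{v^{(k)}},
    /// \f]
    /// where \f$E\f$ is the energy function of the RBM and the angle brackets denote
    /// averages over the conditional distribution of the hidden units given the
    /// respective visible states.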
    double evalDerivative(SearchPointType const& parameter, FirstOrderDerivative& derivative) const{
        mpe_rbm->setParameterVector(parameter);
        derivative.resize(mpe_rbm->numberOfParameters());
        derivative.clear();

        std::size_t batchesForTraining = m_numBatches > 0 ? m_numBatches : m_data.numberOfBatches();
        std::size_t elements = 0;
        //get the batches for this iteration: shuffle the batch indices and count
        //the number of training examples in the chosen prefix
        std::vector<std::size_t> batchIds(m_data.numberOfBatches());
        {
            for(std::size_t i = 0; i != m_data.numberOfBatches(); ++i){
                batchIds[i] = i;
            }
            std::shuffle(batchIds.begin(), batchIds.end(), mpe_rbm->rng());
            for(std::size_t i = 0; i != batchesForTraining; ++i){
                elements += m_data.batch(batchIds[i]).size1();
            }
        }

        //distribute the batches evenly over the available threads
        std::size_t threads = std::min<std::size_t>(batchesForTraining, SHARK_NUM_THREADS);
        std::size_t numBatches = batchesForTraining/threads;

        SHARK_PARALLEL_FOR(int t = 0; t < (int)threads; ++t){
            typename RBM::GradientType empiricalAverage(mpe_rbm);
            typename RBM::GradientType modelAverage(mpe_rbm);

            std::size_t threadElements = 0;

            std::size_t batchStart = t*numBatches;
            std::size_t batchEnd = (t == (int)threads-1) ? batchesForTraining : batchStart+numBatches;
            for(std::size_t i = batchStart; i != batchEnd; ++i){
                RealMatrix const& batch = m_data.batch(batchIds[i]);
                threadElements += batch.size1();

                //create the sample batches for evaluation
                typename Operator::HiddenSampleBatch hiddenBatch(batch.size1(), mpe_rbm->numberOfHN());
                typename Operator::VisibleSampleBatch visibleBatch(batch.size1(), mpe_rbm->numberOfVN());

                //initialize the Gibbs chain with the training examples and accumulate
                //the empirical (data-dependent) term of the gradient
                visibleBatch.state = batch;
                m_operator.precomputeHidden(hiddenBatch, visibleBatch, blas::repeat(1.0, batch.size1()));
                m_operator.sampleHidden(hiddenBatch);
                empiricalAverage.addVH(hiddenBatch, visibleBatch);

                //run the Gibbs chain for k steps; the hidden units are not sampled in
                //the last step, so the model term is computed from their probabilities
                for(std::size_t step = 0; step != m_k; ++step){
                    m_operator.precomputeVisible(hiddenBatch, visibleBatch, blas::repeat(1.0, batch.size1()));
                    m_operator.sampleVisible(visibleBatch);
                    m_operator.precomputeHidden(hiddenBatch, visibleBatch, blas::repeat(1.0, batch.size1()));
                    if(step != m_k-1){
                        m_operator.sampleHidden(hiddenBatch);
                    }
                }
                modelAverage.addVH(hiddenBatch, visibleBatch);
            }
            //add the weighted difference of model and empirical average to the shared
            //derivative; the critical region prevents data races between threads
            SHARK_CRITICAL_REGION{
                double weight = threadElements/double(elements);
                noalias(derivative) += weight*(modelAverage.result() - empiricalAverage.result());
            }
        }

        if(m_regularizer){
            FirstOrderDerivative regularizerDerivative;
            m_regularizer->evalDerivative(parameter, regularizerDerivative);
            noalias(derivative) += m_regularizationStrength*regularizerDerivative;
        }

        //the log-likelihood itself is intractable to evaluate, so no function value is returned
        return std::numeric_limits<double>::quiet_NaN();
    }

private:
    UnlabeledData<RealVector> m_data; ///< the training data
    RBM* mpe_rbm; ///< the RBM to be trained
    Operator m_operator; ///< the operator used for Gibbs sampling
    unsigned int m_k; ///< number of steps of the Gibbs chain
    std::size_t m_numBatches; ///< number of batches used in every iteration. 0 means all.

    SingleObjectiveFunction* m_regularizer; ///< optional regularizer, may be NULL
    double m_regularizationStrength; ///< weight of the regularizer
};

}

#endif