30 #ifndef SHARK_UNSUPERVISED_RBM_GRADIENTAPPROXIMATIONS_CONTRASTIVEDIVERGENCE_H 31 #define SHARK_UNSUPERVISED_RBM_GRADIENTAPPROXIMATIONS_CONTRASTIVEDIVERGENCE_H 43 template<
class Operator>
46 typedef typename Operator::RBM
RBM;
52 : mpe_rbm(rbm),m_operator(rbm)
53 , m_k(1), m_numBatches(0),m_regularizer(0){
63 {
return "ContrastiveDivergence"; }
80 return mpe_rbm->parameterVector();
87 return mpe_rbm->numberOfParameters();
105 m_regularizer = regularizer;
106 m_regularizationStrength = factor;
114 mpe_rbm->setParameterVector(parameter);
115 derivative.resize(mpe_rbm->numberOfParameters());
118 std::size_t batchesForTraining = m_numBatches > 0? m_numBatches: m_data.
numberOfBatches();
119 std::size_t elements = 0;
126 std::shuffle(batchIds.begin(),batchIds.end(),mpe_rbm->rng());
127 for(std::size_t i = 0; i != batchesForTraining; ++i){
128 elements += m_data.
batch(batchIds[i]).size1();
132 std::size_t threads = std::min<std::size_t>(batchesForTraining,
SHARK_NUM_THREADS);
133 std::size_t
numBatches = batchesForTraining/threads;
140 std::size_t threadElements = 0;
143 std::size_t
batchEnd = (t== (int)threads-1)? batchesForTraining : batchStart+
numBatches;
144 for(std::size_t i = batchStart; i !=
batchEnd; ++i){
145 RealMatrix
const& batch = m_data.
batch(batchIds[i]);
146 threadElements += batch.size1();
149 typename Operator::HiddenSampleBatch hiddenBatch(batch.size1(),mpe_rbm->numberOfHN());
150 typename Operator::VisibleSampleBatch visibleBatch(batch.size1(),mpe_rbm->numberOfVN());
152 visibleBatch.state = batch;
153 m_operator.precomputeHidden(hiddenBatch,visibleBatch,blas::repeat(1.0,batch.size1()));
154 m_operator.sampleHidden(hiddenBatch);
155 empiricalAverage.addVH(hiddenBatch,visibleBatch);
157 for(std::size_t step = 0; step != m_k; ++step){
158 m_operator.precomputeVisible(hiddenBatch, visibleBatch,blas::repeat(1.0,batch.size1()));
159 m_operator.sampleVisible(visibleBatch);
160 m_operator.precomputeHidden(hiddenBatch, visibleBatch,blas::repeat(1.0,batch.size1()));
162 m_operator.sampleHidden(hiddenBatch);
165 modelAverage.addVH(hiddenBatch,visibleBatch);
168 double weight = threadElements/double(elements);
169 noalias(derivative) += weight*(modelAverage.result() - empiricalAverage.result());
177 noalias(derivative) += m_regularizationStrength*regularizerDerivative;
180 return std::numeric_limits<double>::quiet_NaN();
188 std::size_t m_numBatches;
191 double m_regularizationStrength;