MarkovChain.h
Go to the documentation of this file.
1 /*!
2  *
3  *
4  * \brief Defines MarkovChain, a single Markov chain that is advanced by applying a transition operator.
5  *
6  * \author -
7  * \date -
8  *
9  *
10  * \par Copyright 1995-2017 Shark Development Team
11  *
12  * <BR><HR>
13  * This file is part of Shark.
14  * <http://shark-ml.org/>
15  *
16  * Shark is free software: you can redistribute it and/or modify
17  * it under the terms of the GNU Lesser General Public License as published
18  * by the Free Software Foundation, either version 3 of the License, or
19  * (at your option) any later version.
20  *
21  * Shark is distributed in the hope that it will be useful,
22  * but WITHOUT ANY WARRANTY; without even the implied warranty of
23  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24  * GNU Lesser General Public License for more details.
25  *
26  * You should have received a copy of the GNU Lesser General Public License
27  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
28  *
29  */
30 #ifndef SHARK_UNSUPERVISED_RBM_SAMPLING_MARKOVCHAIN_H
31 #define SHARK_UNSUPERVISED_RBM_SAMPLING_MARKOVCHAIN_H
32 
33 #include <shark/Data/Dataset.h>
34 #include <shark/Core/Random.h>
36 #include "Impl/SampleTypes.h"
37 namespace shark{
38 
39 /// \brief A single Markov chain.
40 ///
41 /// You can run the Markov chain for some sampling steps by applying a transition operator.
42 template<class Operator>
44 private:
45  typedef typename Operator::HiddenSample HiddenSample;
46  typedef typename Operator::VisibleSample VisibleSample;
47 public:
48 
49  ///\brief The MarkovChain can be used to compute several samples at once.
50  static const bool computesBatch = true;
51 
52  ///\brief The type of the RBM the operator is working with.
53  typedef typename Operator::RBM RBM;
54  ///\brief A batch of samples containing hidden and visible samples as well as the energies.
56 
57  ///\brief Mutable reference to an element of the batch.
58  typedef typename SampleBatch::reference reference;
59 
60  ///\brief Immutable reference to an element of the batch.
61  typedef typename SampleBatch::const_reference const_reference;
62 private:
63  ///\brief The batch of samples containing the state of the visible and the hidden units.
64  SampleBatch m_samples;
65  ///\brief The transition operator.
66  Operator m_operator;
67 public:
68 
69  /// \brief Constructor.
70  MarkovChain(RBM* rbm):m_operator(rbm){}
71 
72 
73  /// \brief Sets the number of parallel samples to be evaluated
74  void setBatchSize(std::size_t batchSize){
75  std::size_t visibles=m_operator.rbm()->numberOfVN();
76  std::size_t hiddens=m_operator.rbm()->numberOfHN();
77  m_samples=SampleBatch(batchSize,visibles,hiddens);
78  }
79  std::size_t batchSize(){
80  return m_samples.size();
81  }
82 
83  /// \brief Initializes with data points drawn uniform from the set.
84  ///
85  /// @param dataSet the data set
86  void initializeChain(Data<RealVector> const& dataSet){
87  std::size_t visibles=m_operator.rbm()->numberOfVN();
88  RealMatrix sampleData(m_samples.size(),visibles);
89 
90  for(std::size_t i = 0; i != m_samples.size(); ++i){
91  noalias(row(sampleData,i)) = dataSet.element(random::discrete(m_operator.rbm()->rng(),std::size_t(0),dataSet.numberOfElements()-1));
92  }
93  initializeChain(sampleData);
94  }
95 
96  /// \brief Initializes with data points from a batch of points
97  ///
98  /// @param sampleData Data set
99  void initializeChain(RealMatrix const& sampleData){
100  m_operator.createSample(m_samples.hidden,m_samples.visible,sampleData);
101  }
102 
103  /// \brief Runs the chain for a given number of steps.
104  ///
105  /// @param numberOfSteps the number of steps
106  void step(unsigned int numberOfSteps){
107  m_operator.stepVH(m_samples.hidden,m_samples.visible,numberOfSteps,blas::repeat(1.0,batchSize()));
108  }
109 
110  /// \brief Returns the current sample of the Markov chain.
111  const_reference sample()const{
112  return const_reference(m_samples,0);
113  }
114 
115  /// \brief Returns the current batch of samples of the Markov chain.
116  SampleBatch const& samples()const{
117  return m_samples;
118  }
119 
120  /// \brief Returns the current batch of samples of the Markov chain.
121  SampleBatch& samples(){
122  return m_samples;
123  }
124 
125  /// \brief Returns the transition operator of the Markov chain.
126  Operator const& transitionOperator()const{
127  return m_operator;
128  }
129 
130  /// \brief Returns the transition operator of the Markov chain.
131  Operator& transitionOperator(){
132  return m_operator;
133  }
134 };
135 
136 }
137 #endif