MultiNomialDistribution.h
Go to the documentation of this file.
1 /*!
2  *
3  *
4  * \brief Implements a multinomial distribution
5  *
6  *
7  *
8  * \author O.Krause
9  * \date 2016
10  *
11  *
12  * \par Copyright 1995-2017 Shark Development Team
13  *
14  * <BR><HR>
15  * This file is part of Shark.
16  * <http://shark-ml.org/>
17  *
18  * Shark is free software: you can redistribute it and/or modify
19  * it under the terms of the GNU Lesser General Public License as published
20  * by the Free Software Foundation, either version 3 of the License, or
21  * (at your option) any later version.
22  *
23  * Shark is distributed in the hope that it will be useful,
24  * but WITHOUT ANY WARRANTY; without even the implied warranty of
25  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
26  * GNU Lesser General Public License for more details.
27  *
28  * You should have received a copy of the GNU Lesser General Public License
29  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
30  *
31  */
32 #ifndef SHARK_STATISTICS_MULTINOMIALDISTRIBUTION_H
33 #define SHARK_STATISTICS_MULTINOMIALDISTRIBUTION_H
34 
35 #include <shark/LinAlg/Base.h>
36 #include <shark/Core/Random.h>
37 
38 namespace shark {
39 
40 /// \brief Implements a multinomial distribution.
41 ///
42 /// A multinomial distribution is a discrete distribution with states 0,...,N-1
43 /// and probabilities p_i for state i with sum_i p_i = 1. This implementation uses
44 /// the fast alias method (Kronmal and Peterson, 1979) to draw the numbers in
45 /// constant time. Setup is O(N) and also quite fast. It is advisable
46 /// to use this method to draw many numbers in succession.
47 ///
48 /// The idea of the alias method is to pair a state with high probability with a state with low
49 /// probability. A high probability state can in this case be included in several pairs. To draw,
50 /// first one of the states is selected and afterwards a coin toss decides which element of the pair
51 /// is taken.
53 public:
54  typedef std::size_t result_type; ///< Type returned by operator(): the index of the sampled state.
55 
57 
58  /// \brief Constructor
59  /// \param [in] probabilities Probability vector
61  : m_probabilities(probabilities){
62  update();
63  }
64 
65  /// \brief Stores/Restores the distribution from the supplied archive.
66  /// \param [in,out] ar The archive to read from/write to.
67  /// \param [in] version Currently unused.
68  template<typename Archive>
69  void serialize( Archive & ar, const unsigned int version ) {
70  ar & BOOST_SERIALIZATION_NVP( m_probabilities );
71  ar & BOOST_SERIALIZATION_NVP( m_q );
72  ar & BOOST_SERIALIZATION_NVP( m_J );
73  }
74 
75  /// \brief Accesses the probabilityvector defining the distribution.
76  RealVector const& probabilities() const {
77  return m_probabilities;
78  }
79 
80  /// \brief Accesses a mutable reference to the probability vector
81  /// defining the distribution. Allows for l-value semantics.
82  ///
83  /// ATTENTION: If the reference is altered, update needs to be called manually.
84  RealVector& probabilities() {
85  return m_probabilities;
86  }
87 
88  /// \brief Samples the distribution.
89  template<class randomType>
90  result_type operator()(randomType& rng) const {
91  std::size_t numStates = m_probabilities.size();
92 
93  std::size_t index = random::discrete(rng,std::size_t(0),numStates-1);
94 
95  if(random::coinToss(rng, m_q[index]))
96  return index;
97  else
98  return m_J[index];
99  }
100 
101 
102  void update() {
103  std::size_t numStates = m_probabilities.size();
104  m_q.resize(numStates);
105  m_J.resize(numStates);
106  m_probabilities/=sum(m_probabilities);
107 
108  // Sort the data into the outcomes with probabilities
109  // that are larger and smaller than 1/K.
110  std::deque<std::size_t> smaller;
111  std::deque<std::size_t> larger;
112  for(std::size_t i = 0;i != numStates; ++i){
113  m_q(i) = numStates*m_probabilities(i);
114  if(m_q(i) < 1.0)
115  smaller.push_back(i);
116  else
117  larger.push_back(i);
118  }
119  // Loop though and create little binary mixtures that
120  // appropriately allocate the larger outcomes over the
121  // overall uniform mixture.
122  while(!smaller.empty() && !larger.empty()){
123  std::size_t smallIndex = smaller.front();
124  std::size_t largeIndex = larger.front();
125  smaller.pop_front();
126  larger.pop_front();
127 
128  m_J[smallIndex] = largeIndex;
129  m_q[largeIndex] -= 1.0 - m_q[smallIndex];
130 
131  if(m_q[largeIndex] < 1.0)
132  smaller.push_back(largeIndex);
133  else
134  larger.push_back(largeIndex);
135  }
136  for(std::size_t i = 0; i != larger.size(); ++i){
137  m_q[larger[i]]=std::min(m_q[larger[i]],1.0);
138  }
139  }
140 
141 private:
142  RealVector m_probabilities; ///< Probability of every state.
143  RealVector m_q; ///< m_q[i] is the probability to draw state i itself, rather than its alias m_J[i], once the pair (i,m_J[i]) was selected.
144  blas::vector<std::size_t> m_J; ///< m_J[i] is the alias of state i, i.e. the second element of the pair (i,m_J[i]).
145 };
146 }
147 
148 #endif