Dirichlet.h
Go to the documentation of this file.
1 /*!
2  *
3  *
4  * \brief Implements a Dirichlet distribution.
5  *
6  *
7  *
8  * \author O. Krause
9  * \date 2010-01-01
10  *
11  *
12  * \par Copyright 1995-2017 Shark Development Team
13  *
14  * <BR><HR>
15  * This file is part of Shark.
16  * <http://shark-ml.org/>
17  *
18  * Shark is free software: you can redistribute it and/or modify
19  * it under the terms of the GNU Lesser General Public License as published
20  * by the Free Software Foundation, either version 3 of the License, or
21  * (at your option) any later version.
22  *
23  * Shark is distributed in the hope that it will be useful,
24  * but WITHOUT ANY WARRANTY; without even the implied warranty of
25  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
26  * GNU Lesser General Public License for more details.
27  *
28  * You should have received a copy of the GNU Lesser General Public License
29  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
30  *
31  */
32 #ifndef SHARK_RNG_DIRICHLET_H
33 #define SHARK_RNG_DIRICHLET_H
34 
35 #include <shark/Rng/Gamma.h>
36 #include <shark/Rng/Rng.h>
37 
38 
39 #include <boost/math/special_functions.hpp>
40 #include <boost/random.hpp>
41 
42 #include <cmath>
43 #include <vector>
44 
45 #ifndef BOOST_RANDOM_NO_STREAM_OPERATORS
46 #include <iostream>
47 #endif
48 
49 namespace shark{
50 
51  //! \brief Dirichlet distribution
52  template<class RealType=double>
54  {
55  public:
56  typedef RealType input_type;
57  typedef std::vector<RealType> result_type;
58 
59  explicit Dirichlet_distribution(size_t n=3,RealType alpha=1)
60  :alphas_(n,alpha)
61  {}
62  explicit Dirichlet_distribution(const std::vector<RealType>& alphas)
63  :alphas_(alphas)
64  {}
65 
66  const std::vector<RealType>& alphas() const
67  {
68  return alphas_;
69  }
70 
71  void reset() { }
72 
73  template<class Engine>
74  result_type operator()(Engine& eng)const
75  {
76  unsigned n = alphas_.size();
77  RealType sum = 0;
78  std::vector<double> x;
79  x.resize(n);
80  for(size_t i=0; i<n; i++)
81  {
82  Gamma_distribution<> gamma(alphas_[i], 1.);
83  x[i] = gamma(eng);
84  sum += x[i];
85  }
86  for(size_t i=0; i<n; i++)
87  x[i]/= sum;
88  return x;
89  }
90 
91 #ifndef BOOST_RANDOM_NO_STREAM_OPERATORS
92  template<class CharT, class Traits>
93  friend std::basic_ostream<CharT,Traits>&
94  operator<<(std::basic_ostream<CharT,Traits>& os, const Dirichlet_distribution& d)
95  {
96  os << d.alphas.size();
97  for(int i=0;i!=d.alphas_.size();++i)
98  os << d.alphas_[i];
99  return os;
100  }
101 
102  template<class CharT, class Traits>
103  friend std::basic_istream<CharT,Traits>&
104  operator>>(std::basic_istream<CharT,Traits>& is, Dirichlet_distribution& d)
105  {
106  size_t size;
107  is >> size;
108  for(int i=0;i!=size;++i)
109  {
110  RealType element;
111  is >> element;
112  d.alphas_.push_back(element);
113  }
114  return is;
115  }
116 #endif
117  private:
118  std::vector<RealType> alphas_;
119  };
120 
121  /**
122  * \brief Implements a Dirichlet distribution.
123  * \tparam RngType The underlying generator type.
124  */
125  template<typename RngType = shark::DefaultRngType>
126  class Dirichlet:public boost::variate_generator<RngType*,Dirichlet_distribution<> >
127  {
128  private:
129  typedef boost::variate_generator<RngType*,Dirichlet_distribution<> > Base;
130  public:
131 
132  /**
133  * \brief C'tor, associates the distribution with the given generator.
134  * \param [in,out] rng Random number generator.
135  * \param [in] n Cardinality.
136  * \param [in] alpha Support value.
137  */
138  explicit Dirichlet(RngType& rng,size_t n=3,double alpha=1)
139  :Base(&rng,Dirichlet_distribution<>(n,alpha))
140  {}
141 
142  /**
143  * \brief C'tor, associates the distribution with the given generator.
144  * \param [in,out] rng Random number generator.
145  * \param [in] alphas Support values.
146  */
147  explicit Dirichlet(RngType& rng,const std::vector<double>& alphas)
148  :Base(&rng,Dirichlet_distribution<>(alphas))
149  {}
150 
151  /** \brief Injects the default sampling operator. */
152  using Base::operator();
153 
154  /**
155  * \brief Creates a temporary instance of the distribution and samples it.
156  * \param [in] n Cardinality.
157  * \param [in] alpha Support value.
158  */
159  std::vector<double> operator()(size_t n,double alpha) {
160  Dirichlet_distribution<> dist(n,alpha);
161  return dist(Base::engine());
162  }
163 
164  /**
165  * \brief Creates a temporary instance of the distribution and samples it.
166  * \param [in] alphas Support values.
167  */
168  std::vector<double> operator()(const std::vector<double> & alphas) {
169  Dirichlet_distribution<> dist(alphas);
170  return dist(Base::engine());
171  }
172 
173  /**
174  * \brief Accesses the support values.
175  */
176  const std::vector<double> alphas()const {
177  return Base::distribution().alphas();
178  }
179 
180  /**
181  * \brief Adjusts the support values.
182  * \param [in] newAlphas New support values.
183  */
184  void alphas(const std::vector<double>& newAlphas) {
185  Base::distribution()=Dirichlet_distribution<>(newAlphas);
186  }
187 
188  /**
189  * \brief Adjusts the support values.
190  * \param [in] n New cardinality.
191  * \param [in] alphas Support value.
192  */
193  void alphas(size_t n,double alphas) {
194  Base::distribution()=Dirichlet_distribution<>(n,alphas);
195  }
196 
197  /**
198  * \brief Calculates the probability of the observation x.
199  */
200  double p(const std::vector<double> &x)const
201  {
202  double p = 1.;
203  double sum = 0.;
204  for(int i=0; i<alphas().size(); i++)
205  {
206  p *= pow(x[i], alphas()[i]-1) / boost::math::tgamma(alphas()[i]);
207  sum += alphas()[i];
208  }
209  return p * boost::math::tgamma(sum);
210  }
211 
212  };
213 }
214 #endif