#ifndef SHARK_OBJECTIVEFUNCTIONS_NEGATIVE_LOG_LIKELIHOOD_H
#define SHARK_OBJECTIVEFUNCTIONS_NEGATIVE_LOG_LIKELIHOOD_H

// ...

    VariationalAutoencoderError(
        DatasetType const& data,
        // ... (encoder, decoder and visible_loss parameters, bound in the initializer list)
    ):mep_decoder(decoder), mep_encoder(encoder), mep_loss(visible_loss), m_data(data){
        // ...
    }

    /// \brief From INameable: return the class name.
    std::string name() const
    { return "VariationalAutoencoderError"; }
    double eval(SearchPointType const& parameters) const{
        // ... (set the encoder/decoder parameters from `parameters` and pick the data batch `batch` from m_data)
        RealMatrix hiddenResponse = (*mep_encoder)(batch);
        // the first half of the encoder output are the means, the second half the log-variances
        auto const& mu = columns(hiddenResponse,0,hiddenResponse.size2()/2);
        auto const& log_var = columns(hiddenResponse,hiddenResponse.size2()/2, hiddenResponse.size2());
        // Kullback-Leibler divergence between N(mu, exp(log_var)) and the prior N(0, I)
        double klError = 0.5 * (sum(exp(log_var)) + sum(sqr(mu)) - mu.size1() * mu.size2() - sum(log_var));
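        // This is the closed form of KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed
        // over all batch rows and latent dimensions:
        //   0.5 * sum_ij( exp(log_var_ij) + mu_ij^2 - 1 - log_var_ij ),
        // the constant "-1" per entry being the mu.size1() * mu.size2() term above.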
        // draw epsilon ~ N(0, I) and apply the reparameterization trick
        RealMatrix epsilon(mu.size1(), mu.size2());
        for(std::size_t i = 0; i != epsilon.size1(); ++i){
            for(std::size_t j = 0; j != epsilon.size2(); ++j){
                epsilon(i,j) = random::gauss(random::globalRng, 0, 1);
            }
        }
        RealMatrix z = mu + exp(0.5*log_var) * epsilon;
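        // Sampling z directly from N(mu, exp(log_var)) would hide the dependence on the
        // encoder outputs inside the random draw; writing z = mu + exp(0.5*log_var) * epsilon
        // with a parameter-free epsilon ~ N(0, I) keeps z differentiable with respect to
        // mu and log_var, which evalDerivative() below relies on.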
        // decode the sampled latent points and add the reconstruction loss to the KL term
        RealMatrix reconstruction = (*mep_decoder)(z);
        return ((*mep_loss)(batch, reconstruction) + klError) / batch.size1();
    }
    double evalDerivative(SearchPointType const& parameters, FirstOrderDerivative& derivative) const{
        // ... (set the encoder/decoder parameters from `parameters` as in eval())
        // states store the intermediate results needed for backpropagation
        boost::shared_ptr<State> stateEncoder = mep_encoder->createState();
        boost::shared_ptr<State> stateDecoder = mep_decoder->createState();
        // ... (pick the data batch `batch` from m_data)
        RealMatrix hiddenResponse;
        mep_encoder->eval(batch,hiddenResponse,*stateEncoder);
        auto const& mu = columns(hiddenResponse,0,hiddenResponse.size2()/2);
        auto const& log_var = columns(hiddenResponse,hiddenResponse.size2()/2, hiddenResponse.size2());
        // KL divergence between N(mu, exp(log_var)) and the prior N(0, I), and its derivative
        double klError = 0.5 * (sum(exp(log_var)) + sum(sqr(mu)) - mu.size1() * mu.size2() - sum(log_var));
        RealMatrix klDerivative = mu | (0.5 * exp(log_var) - 0.5);
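        // Derivative of klError with respect to the encoder output (mu | log_var):
        //   d klError / d mu      = mu
        //   d klError / d log_var = 0.5 * exp(log_var) - 0.5
        // operator| concatenates the two blocks column-wise, matching the layout of hiddenResponse.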
        // draw epsilon ~ N(0, I) and reparameterize, exactly as in eval()
        RealMatrix epsilon(mu.size1(), mu.size2());
        for(std::size_t i = 0; i != epsilon.size1(); ++i){
            for(std::size_t j = 0; j != epsilon.size2(); ++j){
                epsilon(i,j) = random::gauss(random::globalRng, 0, 1);
            }
        }
        RealMatrix z = mu + exp(0.5*log_var) * epsilon;
        RealMatrix reconstructions;
        mep_decoder->eval(z,reconstructions, *stateDecoder);

        // reconstruction error and its derivative with respect to the reconstructions
        RealMatrix lossDerivative;
        double recError = mep_loss->evalDerivative(batch,reconstructions,lossDerivative);
        // backpropagate the loss derivative through the decoder: derivativeDecoder receives
        // the gradient with respect to the decoder parameters, backpropDecoder the gradient
        // with respect to the decoder inputs z
        RealVector derivativeDecoder;
        RealMatrix backpropDecoder;
        mep_decoder->weightedDerivatives(z,reconstructions, lossDerivative,*stateDecoder, derivativeDecoder, backpropDecoder);
        // coefficients for backpropagation through the encoder: the decoder's input
        // derivative, mapped back onto (mu | log_var), plus the KL derivative
        RealMatrix backprop=(backpropDecoder | (backpropDecoder * 0.5*(z - mu))) + klDerivative;
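        // Chain rule through z = mu + exp(0.5*log_var) * epsilon (all operations element-wise):
        //   dz/dmu      = 1
        //   dz/dlog_var = 0.5 * exp(0.5*log_var) * epsilon = 0.5 * (z - mu)
        // so the reconstruction gradient with respect to (mu | log_var) is
        // (backpropDecoder | backpropDecoder * 0.5*(z - mu)), to which the KL derivative is added.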
        // backpropagate through the encoder to obtain its parameter derivative
        RealVector derivativeEncoder;
        // ... (the encoder backward pass with `backprop` as coefficients fills derivativeEncoder,
        //      and `derivative` is resized to the total number of parameters)
        noalias(derivative) = derivativeDecoder|derivativeEncoder;
        derivative /= batch.size1();
        return (recError + klError) / batch.size1();
    }
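// --- Illustrative sketch, not part of the Shark header above -------------------------------
// A minimal, self-contained C++ sketch of the two quantities this objective combines,
// written against the standard library only. The function names klDivergence and
// sampleLatent are made up for illustration and do not exist in Shark.
#include <cmath>
#include <random>
#include <vector>

// Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) ) for one data point.
double klDivergence(std::vector<double> const& mu, std::vector<double> const& log_var){
    double kl = 0.0;
    for(std::size_t j = 0; j != mu.size(); ++j)
        kl += 0.5 * (std::exp(log_var[j]) + mu[j] * mu[j] - 1.0 - log_var[j]);
    return kl;
}

// Reparameterization trick: z = mu + exp(0.5*log_var) * epsilon with epsilon ~ N(0, 1),
// so z is a deterministic, differentiable function of (mu, log_var) given epsilon.
std::vector<double> sampleLatent(
    std::vector<double> const& mu, std::vector<double> const& log_var, std::mt19937& rng
){
    std::normal_distribution<double> gauss(0.0, 1.0);
    std::vector<double> z(mu.size());
    for(std::size_t j = 0; j != mu.size(); ++j)
        z[j] = mu[j] + std::exp(0.5 * log_var[j]) * gauss(rng);
    return z;
}

// Example use of the sketch: with mu = {0, 0} and log_var = {0, 0}, klDivergence returns 0
// and sampleLatent draws from the standard normal prior.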