Computes the variational autoencoder error function.
#include <shark/ObjectiveFunctions/VariationalAutoencoderError.h>
Public Types
typedef UnlabeledData< RealVector > DatasetType
Public Types inherited from shark::AbstractObjectiveFunction< RealVector, double >
enum Feature
List of features that are supported by an implementation.
typedef RealVector SearchPointType
typedef double ResultType
typedef boost::mpl::if_< std::is_arithmetic< double >, SearchPointType, RealMatrix >::type FirstOrderDerivative
typedef TypedFlags< Feature > Features
This statement declares the member m_features. See Core/Flags.h for details.
typedef TypedFeatureNotAvailableException< Feature > FeatureNotAvailableException
Public Member Functions
VariationalAutoencoderError (DatasetType const &data, AbstractModel< RealVector, RealVector > *encoder, AbstractModel< RealVector, RealVector > *decoder, AbstractLoss< RealVector, RealVector > *visible_loss)
std::string name () const
From INameable: return the class name.
SearchPointType proposeStartingPoint () const
Proposes a starting point in the feasible search space of the function.
std::size_t numberOfVariables () const
Accesses the number of variables.
ResultType eval (RealVector const &parameters) const
Evaluates the objective function for the supplied argument.
ResultType evalDerivative (SearchPointType const &parameters, FirstOrderDerivative &derivative) const
Evaluates the objective function and calculates its gradient.
Public Member Functions inherited from shark::AbstractObjectiveFunction< RealVector, double >
const Features & features () const
virtual void updateFeatures ()
bool hasValue () const
Returns whether this function can calculate its function value.
bool hasFirstDerivative () const
Returns whether this function can calculate the first derivative.
bool hasSecondDerivative () const
Returns whether this function can calculate the second derivative.
bool canProposeStartingPoint () const
Returns whether this function can propose a starting point.
bool isConstrained () const
Returns whether this function is constrained.
bool hasConstraintHandler () const
Returns whether this function has a constraint handler.
bool canProvideClosestFeasible () const
Returns whether this function can calculate the closest feasible point to an infeasible point.
bool isThreadSafe () const
Returns true when the function can be used in parallel threads.
bool isNoisy () const
Returns true when the function is noisy, i.e. repeated evaluations at the same point can return different values.
AbstractObjectiveFunction ()
Default ctor.
virtual ~AbstractObjectiveFunction ()
Virtual destructor.
virtual void init ()
void setRng (random::rng_type *rng)
Sets the Rng used by the objective function.
virtual bool hasScalableDimensionality () const
virtual void setNumberOfVariables (std::size_t numberOfVariables)
Adjusts the number of variables if the function is scalable.
virtual std::size_t numberOfObjectives () const
virtual bool hasScalableObjectives () const
virtual void setNumberOfObjectives (std::size_t numberOfObjectives)
Adjusts the number of objectives if the function is scalable.
std::size_t evaluationCounter () const
Accesses the evaluation counter of the function.
AbstractConstraintHandler< SearchPointType > const & getConstraintHandler () const
Returns the constraint handler of the function if it has one.
virtual bool isFeasible (const SearchPointType &input) const
Tests whether a point in SearchSpace is feasible, i.e., whether the constraints are fulfilled.
virtual void closestFeasible (SearchPointType &input) const
If supported, the supplied point is repaired such that it satisfies all of the function's constraints.
ResultType operator() (SearchPointType const &input) const
Evaluates the function. Useful together with STL algorithms like std::transform.
virtual ResultType evalDerivative (SearchPointType const &input, SecondOrderDerivative &derivative) const
Evaluates the objective function and calculates its first and second derivatives.
Public Member Functions inherited from shark::INameable
virtual ~INameable ()
Additional Inherited Members
Protected Member Functions inherited from shark::AbstractObjectiveFunction< RealVector, double >
void announceConstraintHandler (AbstractConstraintHandler< SearchPointType > const *handler)
Helper function which is called to announce the presence of a constraint handler.
Protected Attributes inherited from shark::AbstractObjectiveFunction< RealVector, double >
Features m_features
std::size_t m_evaluationCounter
Evaluation counter, default value: 0.
AbstractConstraintHandler< SearchPointType > const * m_constraintHandler
random::rng_type * mep_rng
Computes the variational autoencoder error function.
We want to optimize a model \( p(x) = \int p(x|z) p(z) \, dz \), where p(z) is chosen to be a multivariate normal distribution and p(x|z) is an arbitrary model, e.g. a deep neural network. The naive approach is to draw samples from p(z) and compute the sample average of p(x|z). This fails when p(z|x) is a very localized distribution, because we may need many samples from p(z) before finding one that is likely under p(z|x). Since p(z|x) is assumed to be intractable to compute, we introduce a second model q(z|x) that approximates p(z|x), and we train it so that it learns the unknown p(z|x). For this, we maximize a variational lower bound on the log-likelihood:
\[ \log p(x) \geq E_{q(z|x)}[\log p(x|z)] - KL[q(z|x) \,||\, p(z)] \]
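Why the right-hand side is a lower bound on \( \log p(x) \) follows from Jensen's inequality applied to the concave logarithm. A short derivation using only the symbols defined above:

\[
\begin{aligned}
\log p(x) &= \log \int q(z|x) \, \frac{p(x|z) \, p(z)}{q(z|x)} \, dz \\
&\geq \int q(z|x) \log \frac{p(x|z) \, p(z)}{q(z|x)} \, dz \\
&= E_{q(z|x)}[\log p(x|z)] - KL[q(z|x) \,||\, p(z)]
\end{aligned}
\]

The gap between the two sides equals \( KL[q(z|x) \,||\, p(z|x)] \), so the bound is tight exactly when q(z|x) matches p(z|x); this is why maximizing the bound pushes q toward the unknown posterior.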
The first term explains the name variational autoencoder: we first sample z given x using the encoder model q and then decode z to obtain an estimate of x. The only difference from a standard autoencoder is that z is now probabilistic. The second term ensures that q learns p(z|x), assuming the model has enough capacity to represent it. See https://arxiv.org/abs/1606.05908 for more background.
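The sampling step described above is commonly implemented with the reparameterization trick, writing \( z = \mu + \sigma \epsilon \) with \( \epsilon \) drawn from a standard normal, so that z remains differentiable with respect to the encoder outputs. This page does not say how VariationalAutoencoderError draws z; the following is a standalone sketch under that assumption. It takes the mean and the log-variance as inputs and uses std::vector and std::mt19937 instead of Shark's RealVector and random::rng_type; the function name sampleLatent is hypothetical.

```cpp
#include <cmath>
#include <cstddef>
#include <random>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not part of Shark): draws z ~ N(mu, diag(exp(logVar)))
// via the reparameterization z = mu + exp(0.5 * logVar) * eps, eps ~ N(0, I).
// Because all randomness sits in eps, z is a differentiable function of
// mu and logVar.
std::vector<double> sampleLatent(std::vector<double> const& mu,
                                 std::vector<double> const& logVar,
                                 std::mt19937& rng) {
    if (mu.size() != logVar.size())
        throw std::invalid_argument("mu and logVar must have equal size");
    std::normal_distribution<double> standardNormal(0.0, 1.0);
    std::vector<double> z(mu.size());
    for (std::size_t i = 0; i < mu.size(); ++i) {
        double eps = standardNormal(rng);             // noise, independent of parameters
        double sigma = std::exp(0.5 * logVar[i]);     // standard deviation from log-variance
        z[i] = mu[i] + sigma * eps;
    }
    return z;
}
```

Passing a very negative log-variance makes \( \sigma \) vanish, so the sample collapses onto the mean; this is a quick sanity check for the helper.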
Implementation notice: we assume q(z|x) to be a product of independent Gaussian distributions parameterized as \( q(z \mid \mu(x), \log \sigma^2(x)) \). The provided encoder model q must therefore have twice as many outputs as the decoder has inputs: the first half of the outputs is interpreted as the mean \( \mu(x) \) and the second half as the log of the variance. So if z should be a 100-dimensional variable, q must have 200 outputs. The outputs and loss function used for the decoder p are arbitrary; a SquaredLoss works well, but other losses, such as ones based on pixel probabilities, can also be used.
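Assuming p(z) is the standard normal \( N(0, I) \), the KL term for a diagonal Gaussian q has the closed form \( \frac{1}{2} \sum_i (\mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1) \). The sketch below splits a hypothetical encoder output into its mean and log-variance halves and combines that KL with a sum-of-squared-differences reconstruction term for one input. The names splitEncoderOutput, klToStandardNormal, squaredError and vaeLossSingleSample are illustrative and do not come from Shark, and the squared error only stands in for whatever visible loss is passed to the constructor.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Splits an encoder output of length 2k into the mean half [0, k)
// and the log-variance half [k, 2k).
std::pair<std::vector<double>, std::vector<double>>
splitEncoderOutput(std::vector<double> const& encoderOut) {
    if (encoderOut.size() % 2 != 0)
        throw std::invalid_argument("encoder output size must be even");
    auto half = static_cast<std::ptrdiff_t>(encoderOut.size() / 2);
    std::vector<double> mu(encoderOut.begin(), encoderOut.begin() + half);
    std::vector<double> logVar(encoderOut.begin() + half, encoderOut.end());
    return {mu, logVar};
}

// Closed-form KL[ N(mu, diag(exp(logVar))) || N(0, I) ]
//   = 0.5 * sum_i (mu_i^2 + exp(logVar_i) - logVar_i - 1)
double klToStandardNormal(std::vector<double> const& mu,
                          std::vector<double> const& logVar) {
    if (mu.size() != logVar.size())
        throw std::invalid_argument("mu and logVar must have equal size");
    double kl = 0.0;
    for (std::size_t i = 0; i < mu.size(); ++i)
        kl += mu[i] * mu[i] + std::exp(logVar[i]) - logVar[i] - 1.0;
    return 0.5 * kl;
}

// Hypothetical stand-in for the visible loss: sum of squared differences.
double squaredError(std::vector<double> const& x,
                    std::vector<double> const& reconstruction) {
    if (x.size() != reconstruction.size())
        throw std::invalid_argument("x and reconstruction must have equal size");
    double sum = 0.0;
    for (std::size_t i = 0; i < x.size(); ++i) {
        double d = x[i] - reconstruction[i];
        sum += d * d;
    }
    return sum;
}

// Single-sample estimate of the quantity to minimize for one input x:
// reconstruction loss of the decoded sample plus the KL term.
double vaeLossSingleSample(std::vector<double> const& x,
                           std::vector<double> const& encoderOut,
                           std::vector<double> const& reconstruction) {
    auto halves = splitEncoderOutput(encoderOut);
    return squaredError(x, reconstruction)
         + klToStandardNormal(halves.first, halves.second);
}
```

Minimizing this estimate, averaged over the data set, corresponds to maximizing the lower bound when the squared error is read as \( -\log p(x|z) \) up to an additive constant, which holds for a Gaussian p(x|z) with fixed variance 1/2.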
Definition at line 60 of file VariationalAutoencoderError.h.
typedef UnlabeledData<RealVector> shark::VariationalAutoencoderError::DatasetType
Definition at line 63 of file VariationalAutoencoderError.h.
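shark::VariationalAutoencoderError::VariationalAutoencoderError (DatasetType const &data, AbstractModel< RealVector, RealVector > *encoder, AbstractModel< RealVector, RealVector > *decoder, AbstractLoss< RealVector, RealVector > *visible_loss) [inline]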
Definition at line 65 of file VariationalAutoencoderError.h.
References shark::AbstractObjectiveFunction< RealVector, double >::CAN_PROPOSE_STARTING_POINT, shark::AbstractObjectiveFunction< RealVector, double >::HAS_FIRST_DERIVATIVE, shark::AbstractModel< InputTypeT, OutputTypeT, ParameterType >::hasFirstParameterDerivative(), shark::AbstractObjectiveFunction< RealVector, double >::IS_NOISY, and shark::AbstractObjectiveFunction< RealVector, double >::m_features.
ResultType shark::VariationalAutoencoderError::eval (RealVector const &parameters) const [inline, virtual]
Evaluates the objective function for the supplied argument.
Parameters
[in] input: The argument for which the function shall be evaluated.
Exceptions
FeatureNotAvailableException: in the default implementation and if a function does not support this feature.
Reimplemented from shark::AbstractObjectiveFunction< RealVector, double >.
Definition at line 89 of file VariationalAutoencoderError.h.
References shark::AbstractObjectiveFunction< RealVector, double >::m_evaluationCounter, numberOfVariables(), shark::IParameterizable< VectorType >::setParameterVector(), and SIZE_CHECK.
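ResultType shark::VariationalAutoencoderError::evalDerivative (SearchPointType const &parameters, FirstOrderDerivative &derivative) const [inline, virtual]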
Evaluates the objective function and calculates its gradient.
Parameters
[in] input: The argument for which the function shall be evaluated.
[out] derivative: The derivative is placed here.
Exceptions
FeatureNotAvailableException: in the default implementation and if a function does not support this feature.
Reimplemented from shark::AbstractObjectiveFunction< RealVector, double >.
Definition at line 115 of file VariationalAutoencoderError.h.
References shark::AbstractObjectiveFunction< RealVector, double >::m_evaluationCounter, numberOfVariables(), shark::IParameterizable< VectorType >::setParameterVector(), and SIZE_CHECK.
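The KL term of the bound is deterministic given the encoder outputs, so its contribution to the gradient has a closed form: \( \partial KL / \partial \mu_i = \mu_i \) and \( \partial KL / \partial \log \sigma_i^2 = \frac{1}{2}(\sigma_i^2 - 1) \), assuming a standard normal prior. The sketch below is not taken from the Shark source; it computes these partial derivatives for hypothetical std::vector inputs and repeats the closed-form KL so the block is self-contained. Propagating them back to the encoder's weights, and handling the noisy reconstruction term, is left to the models' own derivative code.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Closed-form KL[ N(mu, diag(exp(logVar))) || N(0, I) ].
double klToStandardNormal(std::vector<double> const& mu,
                          std::vector<double> const& logVar) {
    double kl = 0.0;
    for (std::size_t i = 0; i < mu.size(); ++i)
        kl += mu[i] * mu[i] + std::exp(logVar[i]) - logVar[i] - 1.0;
    return 0.5 * kl;
}

// Analytic gradient of the KL term with respect to the encoder outputs:
//   dKL/dmu_i     = mu_i
//   dKL/dlogVar_i = 0.5 * (exp(logVar_i) - 1)
void klGradient(std::vector<double> const& mu,
                std::vector<double> const& logVar,
                std::vector<double>& gradMu,
                std::vector<double>& gradLogVar) {
    gradMu.resize(mu.size());
    gradLogVar.resize(logVar.size());
    for (std::size_t i = 0; i < mu.size(); ++i) {
        gradMu[i] = mu[i];
        gradLogVar[i] = 0.5 * (std::exp(logVar[i]) - 1.0);
    }
}
```

A central finite-difference comparison against klToStandardNormal is a cheap way to confirm these formulas before wiring them into a larger gradient computation.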
std::string shark::VariationalAutoencoderError::name () const [inline, virtual]
From INameable: return the class name.
Reimplemented from shark::INameable.
Definition at line 78 of file VariationalAutoencoderError.h.
std::size_t shark::VariationalAutoencoderError::numberOfVariables () const [inline, virtual]
Accesses the number of variables.
Implements shark::AbstractObjectiveFunction< RealVector, double >.
Definition at line 85 of file VariationalAutoencoderError.h.
References shark::IParameterizable< VectorType >::numberOfParameters().
Referenced by eval(), and evalDerivative().
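This page records that numberOfVariables() references numberOfParameters(), that eval() and evalDerivative() reference setParameterVector(), and that proposeStartingPoint() references parameterVector(). One plausible reading, which is an assumption not confirmed here, is that the search point concatenates the encoder's parameters followed by the decoder's. A minimal sketch of that bookkeeping with plain std::vector, using the hypothetical helpers splitSearchPoint and concatParameters:

```cpp
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Assumed (hypothetical) layout: [encoder parameters | decoder parameters].
std::pair<std::vector<double>, std::vector<double>>
splitSearchPoint(std::vector<double> const& point,
                 std::size_t numEncoderParams,
                 std::size_t numDecoderParams) {
    // Under this layout the search space dimension is the sum of both counts.
    if (point.size() != numEncoderParams + numDecoderParams)
        throw std::invalid_argument("search point has the wrong dimension");
    auto cut = point.begin() + static_cast<std::ptrdiff_t>(numEncoderParams);
    return {std::vector<double>(point.begin(), cut),
            std::vector<double>(cut, point.end())};
}

// Inverse of splitSearchPoint: concatenating the current parameter vectors
// of both models yields a natural starting point for an optimizer.
std::vector<double> concatParameters(std::vector<double> const& encoderParams,
                                     std::vector<double> const& decoderParams) {
    std::vector<double> point(encoderParams);
    point.insert(point.end(), decoderParams.begin(), decoderParams.end());
    return point;
}
```

Keeping the split and the concatenation as exact inverses means that a starting point built from the models' current parameters round-trips back into the same two models.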
SearchPointType shark::VariationalAutoencoderError::proposeStartingPoint () const [inline, virtual]
Proposes a starting point in the feasible search space of the function.
Exceptions
FeatureNotAvailableException: in the default implementation and if a function does not support this feature.
Reimplemented from shark::AbstractObjectiveFunction< RealVector, double >.
Definition at line 81 of file VariationalAutoencoderError.h.
References shark::IParameterizable< VectorType >::parameterVector().