A recurrent neural network regression model optimized for online learning. More...
#include <shark/Models/OnlineRNNet.h>
Public Member Functions | |
SHARK_EXPORT_SYMBOL | OnlineRNNet (RecurrentStructure *structure, bool computeGradient) |
Creates a configured neural network. More... | |
std::string | name () const |
From INameable: return the class name. More... | |
SHARK_EXPORT_SYMBOL void | eval (RealMatrix const &pattern, RealMatrix &output, State &state) const |
Feeds a timestep of a time series to the model and calculates its output. The batches must have size 1. More... | |
SHARK_EXPORT_SYMBOL void | eval (RealMatrix const &pattern, RealMatrix &output) const |
It is forbidden to call eval without a state object. More... | |
std::size_t | inputSize () const |
obtain the input dimension More... | |
std::size_t | outputSize () const |
obtain the output dimension More... | |
SHARK_EXPORT_SYMBOL void | weightedParameterDerivative (RealMatrix const &pattern, RealMatrix const &coefficients, State const &state, RealVector &gradient) const |
calculates the weighted sum of gradients w.r.t the parameters More... | |
RealVector | parameterVector () const |
get internal parameters of the model More... | |
void | setParameterVector (RealVector const &newParameters) |
set internal parameters of the model More... | |
std::size_t | numberOfParameters () const |
number of parameters of the network More... | |
boost::shared_ptr< State > | createState () const |
Creates an internal state of the model. More... | |
void | setOutputActivation (State &state, RealVector const &activation) |
Sets the activation of the output neurons. More... | |
Public Member Functions inherited from shark::AbstractModel< RealVector, RealVector > | |
AbstractModel () | |
virtual | ~AbstractModel () |
const Features & | features () const |
virtual void | updateFeatures () |
bool | hasFirstParameterDerivative () const |
Returns true when the first parameter derivative is implemented. More... | |
bool | hasSecondParameterDerivative () const |
Returns true when the second parameter derivative is implemented. More... | |
bool | hasFirstInputDerivative () const |
Returns true when the first input derivative is implemented. More... | |
bool | hasSecondInputDerivative () const |
Returns true when the second input derivative is implemented. More... | |
bool | isSequential () const |
virtual void | read (InArchive &archive) |
From ISerializable, reads a model from an archive. More... | |
virtual void | write (OutArchive &archive) const |
writes a model to an archive More... | |
virtual void | eval (BatchInputType const &patterns, BatchOutputType &outputs) const |
Standard interface for evaluating the response of the model to a batch of patterns. More... | |
virtual void | eval (BatchInputType const &patterns, BatchOutputType &outputs, State &state) const=0 |
Standard interface for evaluating the response of the model to a batch of patterns. More... | |
virtual void | eval (InputType const &pattern, OutputType &output) const |
Standard interface for evaluating the response of the model to a single pattern. More... | |
Data< OutputType > | operator() (Data< InputType > const &patterns) const |
Model evaluation as an operator for a whole dataset. This is a convenience function. More... | |
OutputType | operator() (InputType const &pattern) const |
Model evaluation as an operator for a single pattern. This is a convenience function. More... | |
BatchOutputType | operator() (BatchInputType const &patterns) const |
Model evaluation as an operator for a batch of patterns. This is a convenience function. More... | |
virtual void | weightedParameterDerivative (BatchInputType const &pattern, BatchOutputType const &coefficients, State const &state, RealVector &derivative) const |
calculates the weighted sum of derivatives w.r.t the parameters. More... | |
virtual void | weightedParameterDerivative (BatchInputType const &pattern, BatchOutputType const &coefficients, Batch< RealMatrix >::type const &errorHessian, State const &state, RealVector &derivative, RealMatrix &hessian) const |
calculates the weighted sum of derivatives w.r.t the parameters More... | |
virtual void | weightedInputDerivative (BatchInputType const &pattern, BatchOutputType const &coefficients, State const &state, BatchInputType &derivative) const |
calculates the weighted sum of derivatives w.r.t the inputs More... | |
virtual void | weightedInputDerivative (BatchInputType const &pattern, BatchOutputType const &coefficients, typename Batch< RealMatrix >::type const &errorHessian, State const &state, RealMatrix &derivative, Batch< RealMatrix >::type &hessian) const |
calculates the weighted sum of derivatives w.r.t the inputs More... | |
virtual void | weightedDerivatives (BatchInputType const &patterns, BatchOutputType const &coefficients, State const &state, RealVector &parameterDerivative, BatchInputType &inputDerivative) const |
calculates weighted input and parameter derivative at the same time More... | |
Public Member Functions inherited from shark::IParameterizable | |
virtual | ~IParameterizable () |
Public Member Functions inherited from shark::INameable | |
virtual | ~INameable () |
Public Member Functions inherited from shark::ISerializable | |
virtual | ~ISerializable () |
Virtual destructor. More... | |
void | load (InArchive &archive, unsigned int version) |
Versioned loading of components, calls read(...). More... | |
void | save (OutArchive &archive, unsigned int version) const |
Versioned storing of components, calls write(...). More... | |
BOOST_SERIALIZATION_SPLIT_MEMBER () | |
Protected Attributes | |
RecurrentStructure * | mpe_structure |
the topology of the network. More... | |
bool | m_computeGradient |
stores whether the network should compute a gradient More... | |
Protected Attributes inherited from shark::AbstractModel< RealVector, RealVector > | |
Features | m_features |
Additional Inherited Members | |
Public Types inherited from shark::AbstractModel< RealVector, RealVector > | |
enum | Feature |
typedef RealVector | InputType |
Defines the input type of the model. More... | |
typedef RealVector | OutputType |
Defines the output type of the model. More... | |
typedef Batch< InputType >::type | BatchInputType |
defines the batch type of the input type. More... | |
typedef Batch< OutputType >::type | BatchOutputType |
defines the batch type of the output type More... | |
typedef TypedFlags< Feature > | Features |
typedef TypedFeatureNotAvailableException< Feature > | FeatureNotAvailableException |
A recurrent neural network regression model optimized for online learning.
The OnlineRNNet processes a single input at a time. Internally it stores the last activation as well as the derivatives, which are updated over the course of the sequence. Instead of feeding in the whole sequence at once, the inputs must be given one after another. If the whole sequence is available in advance, this implementation is not advisable, since it is much slower than RNNet, which targets whole sequences.
All network state is stored in the State structure, which is created by createState() and has to be supplied to eval. A new time sequence is started by creating a new state object. When the network is created, the user has to decide whether gradients are needed. If they are, additional resources are allocated in the state object on creation and eval keeps the gradient up to date between steps, which is costly. It is possible to skip updating the parameters at some steps, e.g. when no reward signal is available.
Note that eval only works with batches of size one, and eval cannot be called without a state object.
Definition at line 62 of file OnlineRNNet.h.
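The state-object pattern described above can be illustrated with a small self-contained sketch. The types ToyState and ToyOnlineRNN and the helper runSequence are hypothetical stand-ins written for this illustration; they are not Shark classes and do not reproduce the OnlineRNNet implementation. The sketch only shows the calling convention: create a state, feed one time step per call, and start a new sequence by creating a new state.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-in for the per-sequence state; not Shark's State class.
struct ToyState {
    std::vector<double> activation; // last activation of the recurrent units
};

// Hypothetical toy network in which every connection shares one weight.
class ToyOnlineRNN {
public:
    ToyOnlineRNN(std::size_t inputs, std::size_t units, double weight)
    : m_inputs(inputs), m_units(units), m_weight(weight) {}

    // A fresh state marks the start of a new time series.
    std::shared_ptr<ToyState> createState() const {
        auto state = std::make_shared<ToyState>();
        state->activation.assign(m_units, 0.0);
        return state;
    }

    // Consumes exactly one time step and updates the state in place.
    void eval(std::vector<double> const& input,
              std::vector<double>& output,
              ToyState& state) const {
        assert(input.size() == m_inputs);
        double inputSum = 0.0;
        for (double x : input) inputSum += x;
        double recurrentSum = 0.0;
        for (double a : state.activation) recurrentSum += a;
        std::vector<double> next(m_units);
        for (std::size_t i = 0; i != m_units; ++i)
            next[i] = std::tanh(m_weight * (inputSum + recurrentSum));
        state.activation = next;
        output = state.activation;
    }

private:
    std::size_t m_inputs;
    std::size_t m_units;
    double m_weight;
};

// Feeds a sequence one step at a time and returns the output of the last step.
std::vector<double> runSequence(ToyOnlineRNN const& net,
                                std::vector<std::vector<double>> const& sequence) {
    auto state = net.createState(); // new sequence, new state
    std::vector<double> output;
    for (auto const& step : sequence)
        net.eval(step, output, *state);
    return output;
}
```

Because the model itself holds no per-sequence data, calling runSequence twice on the same input gives the same result, while two consecutive eval calls on one state generally differ because the second call sees the activation left by the first.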
SHARK_EXPORT_SYMBOL shark::OnlineRNNet::OnlineRNNet | ( | RecurrentStructure * | structure, |
bool | computeGradient | ||
) |
Creates a configured neural network.
structure | The structure of the OnlineRNNet |
computeGradient | Whether the network will be used to compute gradients |
boost::shared_ptr< State > shark::OnlineRNNet::createState | ( | ) | const |
inline virtual |
Creates an internal state of the model.
The state is needed when the derivatives are to be calculated. Eval can store a state which is then reused to speed up the calculations of the derivatives. This also allows eval to be evaluated in parallel!
Reimplemented from shark::AbstractModel< RealVector, RealVector >.
Definition at line 152 of file OnlineRNNet.h.
References mpe_structure.
SHARK_EXPORT_SYMBOL void shark::OnlineRNNet::eval | ( | RealMatrix const & | pattern, |
RealMatrix & | output, | ||
State & | state | ||
) | const |
Feeds a timestep of a time series to the model and calculates its output. The batches must have size 1.
pattern | Input patterns for the network. |
output | Used to store the outputs of the network. |
state | the current state of the RNN that is updated by eval |
Referenced by shark::GruauPole::balanceFit(), shark::GruauPole::generalFit(), shark::GruauPole::gruauFit(), and name().
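SHARK_EXPORT_SYMBOL void shark::OnlineRNNet::eval | ( | RealMatrix const & | pattern, |
RealMatrix & | output | ||
) | const |
inline |
It is forbidden to call eval without a state object.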
Definition at line 108 of file OnlineRNNet.h.
References SHARKEXCEPTION.
std::size_t shark::OnlineRNNet::inputSize | ( | ) | const |
inline |
obtain the input dimension
Definition at line 114 of file OnlineRNNet.h.
References shark::RecurrentStructure::inputs(), and mpe_structure.
std::string shark::OnlineRNNet::name | ( | ) | const |
inline virtual |
From INameable: return the class name.
Reimplemented from shark::INameable.
Definition at line 95 of file OnlineRNNet.h.
References eval(), and SHARK_EXPORT_SYMBOL.
std::size_t shark::OnlineRNNet::numberOfParameters | ( | ) | const |
inline virtual |
number of parameters of the network
Reimplemented from shark::IParameterizable.
Definition at line 148 of file OnlineRNNet.h.
References mpe_structure, and shark::RecurrentStructure::parameters().
Referenced by shark::GruauPole::GruauPole(), and shark::NonMarkovPole::NonMarkovPole().
std::size_t shark::OnlineRNNet::outputSize | ( | ) | const |
inline |
obtain the output dimension
Definition at line 119 of file OnlineRNNet.h.
References mpe_structure, shark::RecurrentStructure::outputs(), SHARK_EXPORT_SYMBOL, and weightedParameterDerivative().
Referenced by setOutputActivation().
RealVector shark::OnlineRNNet::parameterVector | ( | ) | const |
inline virtual |
get internal parameters of the model
Reimplemented from shark::IParameterizable.
Definition at line 139 of file OnlineRNNet.h.
References mpe_structure, and shark::RecurrentStructure::parameterVector().
void shark::OnlineRNNet::setOutputActivation | ( | State & | state, |
RealVector const & | activation | ||
) |
inline |
Sets the activation of the output neurons.
This is useful when teacher forcing is used. When the network is trained to predict a time series and diverges from the sequence at an early stage, the resulting gradient might not be very helpful. In this case, teacher forcing can be applied to prevent the divergence. However, the network might become unstable when teacher forcing is turned off, because nothing prevents it from diverging anymore.
state | The current state of the network |
activation | The activation to set for the output neurons. |
Definition at line 168 of file OnlineRNNet.h.
References mpe_structure, shark::RecurrentStructure::numberOfUnits(), outputSize(), remora::subrange(), and shark::State::toState().
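The idea of teacher forcing can be sketched on a one-unit toy predictor. Everything below (ToyState, step, the free function setOutputActivation, run, and the weights wIn and wRec) is hypothetical and written for this illustration only; it is not Shark code and does not mirror how OnlineRNNet stores its activations. After each prediction, the stored output is overwritten with the known target so that the next step starts from the true value instead of the network's own prediction.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical one-unit recurrent state; not Shark's State class.
struct ToyState {
    double output = 0.0; // output of the previous time step
};

// One time step: the new output depends on the input and the previous output.
double step(ToyState& state, double input, double wIn, double wRec) {
    state.output = std::tanh(wIn * input + wRec * state.output);
    return state.output;
}

// Overwrites the stored output activation with a given value.
void setOutputActivation(ToyState& state, double activation) {
    state.output = activation;
}

// Predicts series[t+1] from series[t] and returns the summed squared error.
// With teacherForcing enabled, the stored output is replaced by the target
// after each step.
double run(std::vector<double> const& series, double wIn, double wRec,
           bool teacherForcing) {
    assert(series.size() >= 2); // need at least one (input, target) pair
    ToyState state;
    double error = 0.0;
    for (std::size_t t = 0; t + 1 < series.size(); ++t) {
        double prediction = step(state, series[t], wIn, wRec);
        double diff = prediction - series[t + 1];
        error += diff * diff;
        if (teacherForcing)
            setOutputActivation(state, series[t + 1]);
    }
    return error;
}
```

With wRec set to zero the stored output has no influence, so both modes give the same error; with a non-zero recurrent weight the two trajectories separate after the first step.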
void shark::OnlineRNNet::setParameterVector | ( | RealVector const & | newParameters | ) |
inline virtual |
set internal parameters of the model
Reimplemented from shark::IParameterizable.
Definition at line 143 of file OnlineRNNet.h.
References mpe_structure, and shark::RecurrentStructure::setParameterVector().
Referenced by shark::GruauPole::balanceFit(), shark::GruauPole::generalFit(), and shark::GruauPole::gruauFit().
SHARK_EXPORT_SYMBOL void shark::OnlineRNNet::weightedParameterDerivative | ( | RealMatrix const & | pattern, |
RealMatrix const & | coefficients, | ||
State const & | state, | ||
RealVector & | gradient | ||
) | const |
calculates the weighted sum of gradients w.r.t the parameters
Uses an iterative update scheme to calculate the gradient at timestep t from the gradient at timestep t-1 using forward propagation. This method requires O(n^3) memory and O(n^4) computations, where n is the number of neurons. For very large networks, RNNet should be used instead.
pattern | the pattern to evaluate |
coefficients | the coefficients which are used to calculate the weighted sum |
gradient | the calculated gradient |
state | the current state of the RNN |
Referenced by outputSize().
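The stated memory and time bounds are consistent with a forward-mode recursion of the real-time recurrent learning kind. The following is a sketch in assumed notation: the symbols are illustrative and not taken from the Shark sources, and the bias handling and index layout of the actual implementation may differ. Here a_i(t) = f(s_i(t)) is the activation of unit i, z_l(t) is the l-th entry of the vector that concatenates the current input with the previous activations, w_{il} is the weight from source l to unit i, and c_i are the coefficients passed to weightedParameterDerivative.

```latex
\begin{align}
p^{i}_{kl}(t) &:= \frac{\partial a_i(t)}{\partial w_{kl}}, \qquad p^{i}_{kl}(0) = 0, \\
p^{i}_{kl}(t) &= f'\big(s_i(t)\big) \Big[ \delta_{ik}\, z_l(t)
   + \sum_{j \in \text{units}} w_{ij}\, p^{j}_{kl}(t-1) \Big], \\
g_{kl}(t) &= \sum_{i \in \text{outputs}} c_i\, p^{i}_{kl}(t).
\end{align}
```

Under these assumptions, storing p for n units and O(n^2) weights takes O(n^3) memory, and each of those entries needs an O(n) sum per time step, which gives the O(n^4) cost.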
bool shark::OnlineRNNet::m_computeGradient |
protected |
stores whether the network should compute a gradient
Definition at line 182 of file OnlineRNNet.h.
RecurrentStructure* shark::OnlineRNNet::mpe_structure |
protected |
the topology of the network.
Definition at line 179 of file OnlineRNNet.h.
Referenced by createState(), inputSize(), numberOfParameters(), outputSize(), parameterVector(), setOutputActivation(), and setParameterVector().