AIToolbox
A library that offers tools for AI problem solving.
AIToolbox::MDP::Dyna2< M > Class Template Reference

This class represents the Dyna2 algorithm.
#include <AIToolbox/MDP/Algorithms/Dyna2.hpp>
Public Member Functions

Dyna2 (const M &m, double alpha=0.1, double lambda=0.9, double tolerance=0.001, unsigned n=50)
    Basic constructor.
void stepUpdateQ (size_t s, size_t a, size_t s1, size_t a1, double rew)
    This function updates the internal QFunction.
void batchUpdateQ (size_t s)
    This function updates a QFunction based on simulated experience.
void resetTransientLearning ()
    This function resets the transient QFunction to the permanent one.
void setInternalPolicy (PolicyInterface *p)
    This function sets the policy used to sample during batch updates.
void setPermanentLambda (double l)
    This function sets the new lambda parameter for the permanent SARSAL.
double getPermanentLambda () const
    This function returns the currently set lambda parameter for the permanent SARSAL.
void setTransientLambda (double l)
    This function sets the new lambda parameter for the transient SARSAL.
double getTransientLambda () const
    This function returns the currently set lambda parameter for the transient SARSAL.
void setN (unsigned n)
    This function sets the current sample number parameter.
unsigned getN () const
    This function returns the currently set number of sampling passes during batchUpdateQ().
void setTolerance (double t)
    This function sets the trace cutoff parameter.
double getTolerance () const
    This function returns the currently set trace cutoff parameter.
const QFunction & getPermanentQFunction () const
    This function returns a reference to the internal permanent QFunction.
const QFunction & getTransientQFunction () const
    This function returns a reference to the internal transient QFunction.
const M & getModel () const
    This function returns a reference to the referenced Model.
This class represents the Dyna2 algorithm.
This algorithm leverages the SARSAL algorithm in order to keep two separate QFunctions: one permanent, and one transient.
The permanent QFunction contains what is learned when actually interacting with the real environment. The transient one is instead used to learn against a generative model, so that the agent can explore.
The transient QFunction is always the sum of the permanent one and whatever is learned during batch exploration. After each episode the transient memory should be cleared, to avoid storing information about states that the agent may never encounter again.
Another advantage of clearing the memory is that, if the exploration model is not perfect, whatever imperfect information was learned from it is discarded as well.
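As a rough usage sketch (not taken from the library documentation), the episode structure described above might look as follows; the environment type, its reset/step/isTerminal calls and the selectAction helper are assumptions made purely for illustration, while the Dyna2 member functions are the ones documented on this page.

    #include <cstddef>
    #include <AIToolbox/MDP/Algorithms/Dyna2.hpp>

    // Sketch only: Env and SelectAction are hypothetical; only the Dyna2
    // member functions shown below are part of this class.
    template <typename M, typename Env, typename SelectAction>
    void runEpisode(AIToolbox::MDP::Dyna2<M> & solver, Env & env, SelectAction selectAction) {
        size_t s = env.reset();
        size_t a = selectAction(solver.getPermanentQFunction(), s);
        while (!env.isTerminal(s)) {
            auto [s1, rew] = env.step(s, a);        // real experience
            size_t a1 = selectAction(solver.getPermanentQFunction(), s1);
            solver.stepUpdateQ(s, a, s1, a1, rew);  // online update + sampling list bookkeeping
            solver.batchUpdateQ(s1);                // simulated experience against the model
            s = s1;
            a = a1;
        }
        // Clear the transient memory between episodes, as described above.
        solver.resetTransientLearning();
    }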
AIToolbox::MDP::Dyna2< M >::Dyna2 (const M & m, double alpha = 0.1, double lambda = 0.9, double tolerance = 0.001, unsigned n = 50)  [explicit]
Basic constructor.
Parameters:
    m - The model to be used to update the QFunction.
    alpha - The learning rate of the internal SARSAL methods.
    lambda - The lambda parameter for the eligibility traces.
    tolerance - The cutoff point for eligibility traces.
    n - The number of sampling passes to do on the model upon batchUpdateQ().
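For instance, a solver might be constructed as below. AIToolbox::MDP::Model is used only as an example of a compatible model type, and makeModel() is a hypothetical helper that builds a fully specified model.

    #include <AIToolbox/MDP/Model.hpp>
    #include <AIToolbox/MDP/Algorithms/Dyna2.hpp>

    // makeModel() is assumed to exist and return the MDP model to plan against.
    AIToolbox::MDP::Model model = makeModel();

    AIToolbox::MDP::Dyna2<AIToolbox::MDP::Model> solver(
        model,
        0.2,    // alpha: learning rate of the internal SARSAL methods
        0.9,    // lambda: eligibility trace decay
        0.001,  // tolerance: trace cutoff point
        100     // n: sampling passes per batchUpdateQ() call
    );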
void AIToolbox::MDP::Dyna2< M >::batchUpdateQ (size_t s)
This function updates a QFunction based on simulated experience.
In Dyna2 we sample N times from already experienced state-action pairs, and we update the resulting QFunction as if this experience were actually real.
The idea is that, since we know which state-action pairs we have already explored, we know that those pairs are actually possible. Thus we use the generative model to sample them again, and obtain a better estimate of the QFunction.
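For illustration (assuming a solver built as in the constructor example above, and s holding the agent's current state, which is an assumption about the meaning of this parameter), the planning phase and the resulting estimates might be used like this:

    // Run the simulated-experience phase; getN() sampling passes are
    // performed against the model.
    solver.batchUpdateQ(s);

    // What was learned from simulation (transient) can be inspected
    // separately from what was learned on the real environment (permanent).
    const auto & transient = solver.getTransientQFunction();
    const auto & permanent = solver.getPermanentQFunction();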
const M & AIToolbox::MDP::Dyna2< M >::getModel () const
This function returns a reference to the referenced Model.
unsigned AIToolbox::MDP::Dyna2< M >::getN () const
This function returns the currently set number of sampling passes during batchUpdateQ().
double AIToolbox::MDP::Dyna2< M >::getPermanentLambda () const
This function returns the currently set lambda parameter for the permanent SARSAL.
const QFunction & AIToolbox::MDP::Dyna2< M >::getPermanentQFunction () const
This function returns a reference to the internal permanent QFunction.
double AIToolbox::MDP::Dyna2< M >::getTolerance () const
This function returns the currently set trace cutoff parameter.
double AIToolbox::MDP::Dyna2< M >::getTransientLambda () const
This function returns the currently set lambda parameter for the transient SARSAL.
const QFunction & AIToolbox::MDP::Dyna2< M >::getTransientQFunction () const
This function returns a reference to the internal transient QFunction.
void AIToolbox::MDP::Dyna2< M >::resetTransientLearning ()
This function resets the transient QFunction to the permanent one.
void AIToolbox::MDP::Dyna2< M >::setInternalPolicy (PolicyInterface * p)
This function sets the policy used to sample during batch updates.
This function is provided separately in case you want to base the policy on either the permanent or transient QFunctions, which are internally owned and thus do not exist before this class is actually created.
This function takes ownership of the input policy, and destroys the previous one.
Parameters:
    p - The new policy to use during batch updates.
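A possible sketch, assuming that AIToolbox::MDP::QGreedyPolicy built on the transient QFunction satisfies the PolicyInterface expected here (this pairing is an assumption for illustration, not something stated on this page):

    #include <AIToolbox/MDP/Policies/QGreedyPolicy.hpp>

    // setInternalPolicy() takes ownership of the raw pointer, so it is
    // allocated with new and must not be deleted by the caller.
    solver.setInternalPolicy(
        new AIToolbox::MDP::QGreedyPolicy(solver.getTransientQFunction()));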
void AIToolbox::MDP::Dyna2< M >::setN (unsigned n)
This function sets the current sample number parameter.
Parameters:
    n - The new sample number parameter.
void AIToolbox::MDP::Dyna2< M >::setPermanentLambda (double l)
This function sets the new lambda parameter for the permanent SARSAL.
This parameter determines how much to decrease updates for each timestep in the past.
The lambda parameter must be >= 0.0 and <= 1.0, otherwise the function will throw an std::invalid_argument.
Parameters:
    l - The new lambda parameter.
void AIToolbox::MDP::Dyna2< M >::setTolerance (double t)
This function sets the trace cutoff parameter.
void AIToolbox::MDP::Dyna2< M >::setTransientLambda (double l)
This function sets the new lambda parameter for the transient SARSAL.
This parameter determines how much to decrease updates for each timestep in the past.
The lambda parameter must be >= 0.0 and <= 1.0, otherwise the function will throw an std::invalid_argument.
Parameters:
    l - The new lambda parameter.
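Since both lambda setters reject values outside [0.0, 1.0] by throwing std::invalid_argument, a value coming from user configuration can be guarded as in this sketch (lambdaFromConfig is a made-up variable):

    #include <iostream>
    #include <stdexcept>

    double lambdaFromConfig = 1.2;   // hypothetical runtime value
    try {
        solver.setPermanentLambda(lambdaFromConfig);  // must be in [0.0, 1.0]
        solver.setTransientLambda(lambdaFromConfig);
    } catch (const std::invalid_argument & e) {
        std::cerr << "Rejected lambda value: " << e.what() << '\n';
    }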
void AIToolbox::MDP::Dyna2< M >::stepUpdateQ (size_t s, size_t a, size_t s1, size_t a1, double rew)
This function updates the internal QFunction.
This function takes a single experience point and uses it to update a QFunction. This is a very efficient method to keep the QFunction up to date with the latest experience.
In addition, the sampling list is updated so that batch updating becomes possible as a second phase.
The sampling list in Dyna2 is a simple list of all visited state-action pairs. This function is responsible for inserting them in a set, keeping them unique.
Parameters:
    s - The previous state.
    a - The action performed.
    s1 - The new state.
    a1 - The action performed in the new state.
    rew - The reward obtained.
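To make the parameter roles concrete, a single online step might look like the following sketch; the environment call and the action-selection policy are assumptions made for illustration, and only stepUpdateQ() belongs to this class.

    // SARSA-style update: both the action taken in s and the action chosen
    // in the new state s1 are needed before calling stepUpdateQ().
    auto [s1, rew] = env.step(s, a);        // hypothetical environment interaction
    size_t a1 = policy.sampleAction(s1);    // hypothetical action selection
    solver.stepUpdateQ(s, a, s1, a1, rew);  // updates the QFunction and records (s, a)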