AIToolbox
A library that offers tools for AI problem solving.
AIToolbox::MDP::SARSAL Class Reference

This class represents the SARSAL algorithm.

#include <AIToolbox/MDP/Algorithms/SARSAL.hpp>

Public Types

using Trace = std::tuple< size_t, size_t, double >
 
using Traces = std::vector< Trace >
 

Public Member Functions

 SARSAL (size_t S, size_t A, double discount=1.0, double alpha=0.1, double lambda=0.9, double tolerance=0.001)
 Basic constructor.

template<IsGenerativeModel M>
 SARSAL (const M &model, double alpha=0.1, double lambda=0.9, double tolerance=0.001)
 Basic constructor.

void stepUpdateQ (size_t s, size_t a, size_t s1, size_t a1, double rew)
 This function updates the internal QFunction using the currently set discount.

void setLearningRate (double a)
 This function sets the learning rate parameter.

double getLearningRate () const
 This function returns the currently set learning rate parameter.

void setDiscount (double d)
 This function sets the new discount parameter.

double getDiscount () const
 This function returns the currently set discount parameter.

void setLambda (double l)
 This function sets the new lambda parameter.

double getLambda () const
 This function returns the currently set lambda parameter.

void setTolerance (double t)
 This function sets the trace cutoff parameter.

double getTolerance () const
 This function returns the currently set trace cutoff parameter.

void clearTraces ()
 This function clears the already set traces.

const Traces & getTraces () const
 This function returns the currently set traces.

void setTraces (const Traces &t)
 This function replaces the currently set traces.

size_t getS () const
 This function returns the number of states on which SARSAL is working.

size_t getA () const
 This function returns the number of actions on which SARSAL is working.

const QFunction & getQFunction () const
 This function returns a reference to the internal QFunction.

void setQFunction (const QFunction &qfun)
 This function directly sets the internal QFunction.
 

Detailed Description

This class represents the SARSAL algorithm.

This algorithm adds eligibility traces to the SARSA algorithm.

See also
SARSA

To use the collected data more effectively, SARSAL keeps a list of previously visited state/action pairs, which are updated together with the most recent transition. All of them are updated with the same TD error, but pairs visited further in the past are updated less: each pair's coefficient is multiplied by discount*lambda for every timestep that passes. Once this coefficient falls below a certain threshold (the tolerance), the pair is forgotten and no longer updated. If instead the pair is visited again, its coefficient is raised again.

The idea is to give credit to past actions for the current reward in an efficient manner. This reduces the amount of data needed to propagate rewards backward, and allows SARSAL to learn faster.

This particular version of the algorithm implements capped traces (also known as replacing traces): every time a state/action pair is visited, its eligibility trace is reset to 1.0 instead of being incremented. This avoids the potentially diverging values that can occur with standard (accumulating) eligibility traces.
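The update described above can be sketched as follows. This is a minimal, self-contained illustration and not AIToolbox's implementation: the QFunction is assumed to be a plain `std::vector<std::vector<double>>`, each trace is assumed to hold (state, action, eligibility), and the name `SarsaLambdaSketch` is hypothetical. Traces are aged at the start of each step, which yields the same eligibilities at update time as decaying them at the end of the previous step; the exact point at which SARSAL applies the cutoff may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical stand-in for SARSAL's internals; not AIToolbox's implementation.
struct SarsaLambdaSketch {
    // Assumed layout: (state, action, eligibility).
    using Trace = std::tuple<std::size_t, std::size_t, double>;

    std::vector<std::vector<double>> q;  // q[s][a], all zero initially
    std::vector<Trace> traces;
    double discount, alpha, lambda, tolerance;

    SarsaLambdaSketch(std::size_t S, std::size_t A, double d, double al, double l, double tol)
        : q(S, std::vector<double>(A, 0.0)), discount(d), alpha(al), lambda(l), tolerance(tol) {}

    void stepUpdateQ(std::size_t s, std::size_t a, std::size_t s1, std::size_t a1, double rew) {
        assert(s < q.size() && s1 < q.size() && a < q[s].size() && a1 < q[s1].size());
        // TD error shared by every traced pair, computed before any value changes.
        const double delta = rew + discount * q[s1][a1] - q[s][a];

        std::vector<Trace> kept;
        bool seen = false;
        for (auto [ts, ta, el] : traces) {
            if (ts == s && ta == a) {
                seen = true;
                el = 1.0;                      // capped trace: reset, never accumulate
            } else {
                el *= discount * lambda;       // one more timestep in the past
                if (el < tolerance) continue;  // too small: forget this pair
            }
            kept.emplace_back(ts, ta, el);
        }
        if (!seen) kept.emplace_back(s, a, 1.0);
        traces = std::move(kept);

        // Every remembered pair moves by the same TD error, scaled by its eligibility.
        for (const auto& [ts, ta, el] : traces)
            q[ts][ta] += alpha * delta * el;
    }
};
```

Rebuilding `traces` on every step keeps the sketch short; an in-place compaction would avoid the extra allocation.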

Member Typedef Documentation

◆ Trace

using AIToolbox::MDP::SARSAL::Trace = std::tuple<size_t, size_t, double>

◆ Traces

using AIToolbox::MDP::SARSAL::Traces = std::vector<Trace>
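The typedefs above do not name the tuple's fields. The helper below assumes, based on the description of the algorithm, that a Trace holds (state, action, eligibility); `eligibilityOf` is a hypothetical name, not part of AIToolbox, and the two typedefs are redeclared locally so the snippet compiles on its own.

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>
#include <vector>

// Local copies of the two documented typedefs.
using Trace = std::tuple<std::size_t, std::size_t, double>;
using Traces = std::vector<Trace>;

// Hypothetical helper: assumes each Trace holds (state, action, eligibility).
// Returns 0.0 when the pair is not currently being traced.
double eligibilityOf(const Traces& traces, std::size_t s, std::size_t a) {
    for (const auto& [ts, ta, el] : traces)
        if (ts == s && ta == a) return el;
    return 0.0;
}
```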

Constructor & Destructor Documentation

◆ SARSAL() [1/2]

AIToolbox::MDP::SARSAL::SARSAL ( size_t  S,
size_t  A,
double  discount = 1.0,
double  alpha = 0.1,
double  lambda = 0.9,
double  tolerance = 0.001 
)

Basic constructor.

The learning rate must be > 0.0 and <= 1.0, otherwise the constructor will throw an std::invalid_argument.

Parameters
S: The size of the state space of the underlying model.
A: The size of the action space of the underlying model.
discount: The discount factor of the underlying model.
alpha: The learning rate of the SARSAL method.
lambda: The lambda parameter for the eligibility traces.
tolerance: The cutoff point for eligibility traces.
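The rule on the learning rate amounts to a half-open interval check, alpha in (0, 1]. A minimal sketch of such a check follows; `checkedLearningRate` is a hypothetical helper, and AIToolbox's own check may be written differently.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical helper mirroring the documented rule; not AIToolbox code.
double checkedLearningRate(double alpha) {
    // Written as a negated range test so that NaN is rejected as well.
    if (!(alpha > 0.0 && alpha <= 1.0))
        throw std::invalid_argument("learning rate must be in (0, 1]");
    return alpha;
}
```

Writing the condition as `!(alpha > 0.0 && alpha <= 1.0)` also rejects NaN, which the seemingly equivalent `alpha <= 0.0 || alpha > 1.0` would let through, since every comparison with NaN is false.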

◆ SARSAL() [2/2]

template<IsGenerativeModel M>
AIToolbox::MDP::SARSAL::SARSAL ( const M &  model,
double  alpha = 0.1,
double  lambda = 0.9,
double  tolerance = 0.001 
)

Basic constructor.

The learning rate must be > 0.0 and <= 1.0, otherwise the constructor will throw an std::invalid_argument.

This constructor copies the S, A, and discount parameters from the supplied model. It does not keep a reference to the model, so if the model's discount changes later, you'll need to update it here manually too (for example with setDiscount()).
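The copy-once behaviour can be illustrated with a toy model. Both `ToyModel` and `LearnerSketch` are hypothetical; the sketch assumes only that the model exposes `getS()`, `getA()` and `getDiscount()`, and uses an unconstrained template where the real constructor uses the IsGenerativeModel concept, which likely requires more (for example a way to sample transitions).

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical toy model exposing the three quantities the constructor copies.
struct ToyModel {
    std::size_t S = 4, A = 2;
    double discount = 0.9;
    std::size_t getS() const { return S; }
    std::size_t getA() const { return A; }
    double getDiscount() const { return discount; }
};

// Hypothetical learner: reads the values once and keeps no reference to the model.
class LearnerSketch {
public:
    template <typename M>
    explicit LearnerSketch(const M& model)
        : S_(model.getS()), A_(model.getA()), discount_(model.getDiscount()) {}
    std::size_t getS() const { return S_; }
    std::size_t getA() const { return A_; }
    double getDiscount() const { return discount_; }
    void setDiscount(double d) { discount_ = d; }
private:
    std::size_t S_, A_;
    double discount_;
};
```

Changing the model afterwards leaves the learner's copy untouched until `setDiscount` is called explicitly.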

Parameters
model: The MDP model that SARSAL will use as a base.
alpha: The learning rate of the SARSAL method.
lambda: The lambda parameter for the eligibility traces.
tolerance: The cutoff point for eligibility traces.

Member Function Documentation

◆ clearTraces()

void AIToolbox::MDP::SARSAL::clearTraces ( )

This function clears the already set traces.

◆ getA()

size_t AIToolbox::MDP::SARSAL::getA ( ) const

This function returns the number of actions on which SARSAL is working.

Returns
The number of actions.

◆ getDiscount()

double AIToolbox::MDP::SARSAL::getDiscount ( ) const

This function returns the currently set discount parameter.

Returns
The currently set discount parameter.

◆ getLambda()

double AIToolbox::MDP::SARSAL::getLambda ( ) const

This function returns the currently set lambda parameter.

Returns
The currently set lambda parameter.

◆ getLearningRate()

double AIToolbox::MDP::SARSAL::getLearningRate ( ) const

This function returns the currently set learning rate parameter.

Returns
The currently set learning rate parameter.

◆ getQFunction()

const QFunction& AIToolbox::MDP::SARSAL::getQFunction ( ) const

This function returns a reference to the internal QFunction.

The returned reference can be used to build Policies, for example MDP::QGreedyPolicy.

Returns
The internal QFunction.
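As an illustration of what a greedy policy does with the returned values, the sketch below picks the highest-valued action in a state. It is not MDP::QGreedyPolicy: the `QTable` alias stands in for QFunction as a plain nested vector, `greedyAction` is a hypothetical name, and ties are broken by lowest index here, whereas a real policy may break them differently.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in: q[s][a], not AIToolbox's QFunction type.
using QTable = std::vector<std::vector<double>>;

// Returns the action with the highest value in state s (lowest index on ties).
std::size_t greedyAction(const QTable& q, std::size_t s) {
    assert(s < q.size() && !q[s].empty());
    std::size_t best = 0;
    for (std::size_t a = 1; a < q[s].size(); ++a)
        if (q[s][a] > q[s][best]) best = a;
    return best;
}
```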

◆ getS()

size_t AIToolbox::MDP::SARSAL::getS ( ) const

This function returns the number of states on which SARSAL is working.

Returns
The number of states.

◆ getTolerance()

double AIToolbox::MDP::SARSAL::getTolerance ( ) const

This function returns the currently set trace cutoff parameter.

Returns
The currently set trace cutoff parameter.

◆ getTraces()

const Traces& AIToolbox::MDP::SARSAL::getTraces ( ) const

This function returns the currently set traces.

Returns
The currently set traces.

◆ setDiscount()

void AIToolbox::MDP::SARSAL::setDiscount ( double  d)

This function sets the new discount parameter.

The discount parameter controls how much future rewards are taken into account by SARSAL. If 1, a reward is worth the same whether it is obtained now or in a million timesteps, so the algorithm optimizes the total accumulated reward. When less than 1, rewards obtained in the present are valued more than future rewards.

Parameters
d: The new discount factor.
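The effect of the discount can be made concrete with the discounted return of a reward sequence, the sum of discount^t * r_t. The helper name `discountedReturn` is hypothetical; it only illustrates the quantity the discount shapes, not a SARSAL function.

```cpp
#include <cassert>
#include <vector>

// Discounted return of a reward sequence: sum over t of discount^t * rewards[t].
double discountedReturn(const std::vector<double>& rewards, double discount) {
    assert(discount >= 0.0 && discount <= 1.0);
    double total = 0.0;
    double weight = 1.0;  // discount^t for the current t
    for (double r : rewards) {
        total += weight * r;
        weight *= discount;
    }
    return total;
}
```

With a discount of 1.0 the rewards {1, 1, 1} sum to 3, while with 0.5 they are worth 1 + 0.5 + 0.25 = 1.75, and a reward of 10 received immediately is worth more than the same reward received two steps later.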

◆ setLambda()

void AIToolbox::MDP::SARSAL::setLambda ( double  l)

This function sets the new lambda parameter.

This parameter determines how much updates are decreased for each timestep in the past. If set to 0, SARSAL becomes equivalent to SARSA, since no credit is propagated back to earlier pairs. If set to 1, it behaves similarly to Monte Carlo sampling, where rewards are backed up from the end of the episode to its beginning (still subject to the discount of the model).

The lambda parameter must be >= 0.0 and <= 1.0, otherwise the function will throw an std::invalid_argument.

Parameters
l: The new lambda parameter.
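The two extremes described above follow from the eligibility a pair keeps after k timesteps, (discount*lambda)^k. The sketch below, with the hypothetical name `eligibilityWeights`, lists those weights for the most recent pairs; it ignores the tolerance cutoff and assumes no pair is revisited.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Weight each of the last `steps` distinct pairs receives on the current update:
// the pair visited k steps ago has eligibility (discount * lambda)^k.
// Ignores the tolerance cutoff and assumes no pair is revisited.
std::vector<double> eligibilityWeights(std::size_t steps, double discount, double lambda) {
    assert(lambda >= 0.0 && lambda <= 1.0);
    std::vector<double> weights;
    double w = 1.0;
    for (std::size_t k = 0; k < steps; ++k) {
        weights.push_back(w);
        w *= discount * lambda;
    }
    return weights;
}
```

With lambda = 0 only the current pair has nonzero weight (one-step SARSA), while with lambda = 1 and discount = 1 every past pair receives the full update, as in a Monte Carlo backup.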

◆ setLearningRate()

void AIToolbox::MDP::SARSAL::setLearningRate ( double  a)

This function sets the learning rate parameter.

The learning rate determines how quickly the QFunction is modified in response to new data. In fully deterministic environments (such as an agent moving through a grid), this parameter can be safely set to 1.0 for maximum learning speed.

In stochastic environments, on the other hand, this parameter should be higher when learning starts and decrease slowly over time in order to converge.

Alternatively, it can be kept relatively high if the environment dynamics change over time, so that the algorithm keeps adapting to them. The final behaviour of SARSAL depends heavily on this parameter.

The learning rate parameter must be > 0.0 and <= 1.0, otherwise the function will throw an std::invalid_argument.

Parameters
a: The new learning rate parameter.
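One common decreasing schedule is alpha = 1/n, where n counts the updates so far; on a single running estimate it reproduces the sample mean, and it always stays within the valid range (0, 1]. The `DecayingEstimate` struct is hypothetical; with SARSAL, a global schedule of this kind would be applied by calling setLearningRate between steps.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: a running estimate whose learning rate decays as 1/n.
struct DecayingEstimate {
    double value = 0.0;
    std::size_t visits = 0;

    // Moves value toward target with alpha = 1/n, which always lies in (0, 1].
    void update(double target) {
        ++visits;
        const double alpha = 1.0 / static_cast<double>(visits);
        value += alpha * (target - value);
    }
};
```

Feeding the targets 2, 4 and 6 leaves `value` at their mean, 4, because each step weighs the new target by 1/n.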

◆ setQFunction()

void AIToolbox::MDP::SARSAL::setQFunction ( const QFunction & qfun)

This function directly sets the internal QFunction.

This can be useful in order to use a QFunction that has already been computed elsewhere. SARSAL will then continue building upon it.

This is used, for example, in the Dyna2 algorithm.

Parameters
qfun: The new QFunction to set.

◆ setTolerance()

void AIToolbox::MDP::SARSAL::setTolerance ( double  t)

This function sets the trace cutoff parameter.

This parameter determines when a trace is removed because its coefficient has become too small to be worth updating.

Note that the trace cutoff is performed on the overall discount*lambda value, and not only on lambda. So this parameter is useful even when lambda is 1.

Parameters
t: The new trace cutoff value.
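Because the cutoff applies to the product discount*lambda, the lifetime of a trace that is never revisited can be estimated by counting decay steps. The helper `stepsUntilCutoff` is hypothetical and assumes a trace starts at 1.0, is multiplied by discount*lambda once per step, and is dropped as soon as it falls below the tolerance; SARSAL's exact comparison may differ by one step.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>

// Number of decay steps before a trace starting at 1.0 falls below the cutoff,
// assuming it is multiplied by discount * lambda once per step and never revisited.
// Returns std::nullopt when discount * lambda >= 1, since the trace never shrinks.
std::optional<std::size_t> stepsUntilCutoff(double discount, double lambda, double tolerance) {
    assert(tolerance > 0.0);
    const double factor = discount * lambda;
    if (factor >= 1.0) return std::nullopt;
    double el = 1.0;
    std::size_t steps = 0;
    while (el >= tolerance) {
        el *= factor;
        ++steps;
    }
    return steps;
}
```

With the constructor defaults (discount = 1.0, lambda = 0.9, tolerance = 0.001) a trace survives 66 steps, and the same holds for discount = 0.9 with lambda = 1, which is why the cutoff still matters when lambda is 1.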

◆ setTraces()

void AIToolbox::MDP::SARSAL::setTraces ( const Traces & t)

This function replaces the currently set traces.

This method is provided in case you need to tinker with the internal traces. You generally won't need it unless you are building something more complicated on top of SARSAL.

Parameters
t: The new traces.
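As an example of such tinkering, the sketch below forgets every trace that points at a given state, using the erase-remove idiom. The helper `withoutState` is hypothetical and assumes the (state, action, eligibility) tuple layout; with a SARSAL instance it might be used as `learner.setTraces(withoutState(learner.getTraces(), s));`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <tuple>
#include <vector>

using Trace = std::tuple<std::size_t, std::size_t, double>;
using Traces = std::vector<Trace>;

// Hypothetical tweak: returns a copy of the traces without any entry for state s,
// assuming the tuple layout (state, action, eligibility).
Traces withoutState(Traces traces, std::size_t s) {
    traces.erase(std::remove_if(traces.begin(), traces.end(),
                                [s](const Trace& t) { return std::get<0>(t) == s; }),
                 traces.end());
    return traces;
}
```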

◆ stepUpdateQ()

void AIToolbox::MDP::SARSAL::stepUpdateQ ( size_t  s,
size_t  a,
size_t  s1,
size_t  a1,
double  rew 
)

This function updates the internal QFunction using the currently set discount.

This function takes a single experience point and uses it to update the QFunction. This is a very efficient method to keep the QFunction up to date with the latest experience.

Keep in mind that, since SARSAL needs to compute the QFunction for the currently used policy, it needs to know two consecutive state-action pairs, in order to correctly relate how the policy acts from state to state.

Parameters
s: The previous state.
a: The action performed.
s1: The new state.
a1: The action performed in the new state.
rew: The reward obtained.
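Because stepUpdateQ needs the next action a1, an on-policy loop has to choose a1 in the new state before updating, and then actually take it. The sketch below shows that ordering with hypothetical types: `RecordingLearner` only records its arguments in place of SARSAL, and `Chain` is a toy deterministic environment. Between episodes one would typically also call clearTraces(), so that credit does not leak across episode boundaries.

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical learner that only records the arguments it receives.
struct RecordingLearner {
    using Call = std::tuple<std::size_t, std::size_t, std::size_t, std::size_t, double>;
    std::vector<Call> calls;
    void stepUpdateQ(std::size_t s, std::size_t a, std::size_t s1, std::size_t a1, double rew) {
        calls.emplace_back(s, a, s1, a1, rew);
    }
};

// Hypothetical deterministic chain: action 1 moves right, any other action stays.
// Entering the last state yields reward 1.
struct Chain {
    std::size_t last;
    std::pair<std::size_t, double> step(std::size_t s, std::size_t a) const {
        const std::size_t s1 = (a == 1 && s < last) ? s + 1 : s;
        return {s1, s1 == last ? 1.0 : 0.0};
    }
};

// SARSA-style loop: a1 is chosen in s1 before the update and becomes the next action.
template <typename Learner, typename Policy>
void runEpisode(const Chain& env, Learner& learner, Policy policy,
                std::size_t s, std::size_t maxSteps) {
    assert(s <= env.last);
    std::size_t a = policy(s);
    for (std::size_t t = 0; t < maxSteps && s != env.last; ++t) {
        const auto [s1, rew] = env.step(s, a);
        const std::size_t a1 = policy(s1);  // next action from the same policy
        learner.stepUpdateQ(s, a, s1, a1, rew);
        s = s1;
        a = a1;  // on-policy: the action used in the update is the one taken next
    }
}
```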
