AIToolbox
A library that offers tools for AI problem solving.
AIToolbox::MDP::ExpectedSARSA Class Reference

This class represents the ExpectedSARSA algorithm. More...

#include <AIToolbox/MDP/Algorithms/ExpectedSARSA.hpp>

Public Member Functions

 ExpectedSARSA (QFunction &qfun, const PolicyInterface &policy, double discount=0.0, double alpha=0.1)
 Basic constructor. More...
 
template<IsGenerativeModel M>
 ExpectedSARSA (QFunction &qfun, const PolicyInterface &policy, const M &model, double alpha=0.1)
 Basic constructor. More...
 
void setLearningRate (double a)
 This function sets the learning rate parameter. More...
 
double getLearningRate () const
 This function returns the currently set learning rate parameter. More...
 
void setDiscount (double d)
 This function sets the new discount parameter. More...
 
double getDiscount () const
 This function returns the currently set discount parameter. More...
 
void stepUpdateQ (size_t s, size_t a, size_t s1, double rew)
 This function updates the internal QFunction using the discount set during construction. More...
 
size_t getS () const
 This function returns the number of states on which ExpectedSARSA is working. More...
 
size_t getA () const
 This function returns the number of actions on which ExpectedSARSA is working. More...
 
const QFunction & getQFunction () const
 This function returns a reference to the internal QFunction. More...
 
const PolicyInterface & getPolicy () const
 This function returns a reference to the policy used by ExpectedSARSA. More...
 

Detailed Description

This class represents the ExpectedSARSA algorithm.

This algorithm is a subtle improvement over the SARSA algorithm.

See also
SARSA

The difference between this algorithm and the original SARSA algorithm lies in the value used to approximate the value of the next timestep. In standard SARSA this value is taken directly from the QFunction's current approximation for the newly sampled state and the next action to be performed (the final "SA" in "SARSA").

In Expected SARSA this value is instead replaced by the expected value for the newly sampled state, given the policy from which we will sample the next action. In this sense Expected SARSA is more similar to QLearning: where QLearning uses the max over the QFunction for the next state, Expected SARSA uses the future expectation over the current online policy.

This considerably reduces the variance of the updates, which in turn allows the learning rate to be increased somewhat, letting Expected SARSA learn faster than simple SARSA. All guarantees of normal SARSA are maintained.
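
The difference is easiest to see side by side. The following is a minimal sketch in plain C++ (it does not use the library's types; Q, pi and the function names are purely illustrative) comparing the bootstrap targets used by SARSA, QLearning and Expected SARSA:

#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

// Q is a tabular value function indexed as Q[state][action]; pi(s1, a1)
// returns the probability the online policy assigns to action a1 in the
// new state s1.

double sarsaTarget(const std::vector<std::vector<double>> & Q,
                   double rew, double discount, std::size_t s1, std::size_t a1) {
    // SARSA bootstraps on the action actually sampled in the new state.
    return rew + discount * Q[s1][a1];
}

double qLearningTarget(const std::vector<std::vector<double>> & Q,
                       double rew, double discount, std::size_t s1) {
    // QLearning bootstraps on the best action in the new state.
    return rew + discount * *std::max_element(Q[s1].begin(), Q[s1].end());
}

double expectedSarsaTarget(const std::vector<std::vector<double>> & Q,
                           const std::function<double(std::size_t, std::size_t)> & pi,
                           double rew, double discount, std::size_t s1) {
    // Expected SARSA bootstraps on the expectation over the online policy.
    double expectation = 0.0;
    for (std::size_t a1 = 0; a1 < Q[s1].size(); ++a1)
        expectation += pi(s1, a1) * Q[s1][a1];
    return rew + discount * expectation;
}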

Constructor & Destructor Documentation

◆ ExpectedSARSA() [1/2]

AIToolbox::MDP::ExpectedSARSA::ExpectedSARSA ( QFunction &  qfun,
const PolicyInterface &  policy,
double  discount = 0.0,
double  alpha = 0.1 
)

Basic constructor.

Note that, unlike normal SARSA, ExpectedSARSA does not contain its own QFunction. This is because many policies are implemented in terms of a QFunction that is continuously updated by a method (e.g. QGreedyPolicy).

At the same time, ExpectedSARSA needs this policy in order to perform its expected value computation. To avoid a chicken and egg problem, ExpectedSARSA takes a QFunction as a parameter, so the user can create it once and use the same one for both ExpectedSARSA and the policy.

The learning rate must be > 0.0 and <= 1.0, otherwise the constructor will throw an std::invalid_argument.

Parameters
qfun    The QFunction underlying the ExpectedSARSA algorithm.
policy    The policy used to select actions.
discount    The discount of the underlying MDP model.
alpha    The learning rate of the ExpectedSARSA method.
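
A typical setup shares a single QFunction between a greedy policy and the learner. In the sketch below, only the ExpectedSARSA header is taken from this page; the include paths for the policies and the makeQFunction helper are assumptions based on the rest of the library, so check them against your AIToolbox version:

#include <AIToolbox/MDP/Algorithms/ExpectedSARSA.hpp>
#include <AIToolbox/MDP/Utils.hpp>                   // makeQFunction (assumed path)
#include <AIToolbox/MDP/Policies/QGreedyPolicy.hpp>  // assumed path
#include <AIToolbox/MDP/Policies/EpsilonPolicy.hpp>  // assumed path
#include <cstddef>

int main() {
    constexpr std::size_t S = 10, A = 4;

    // Create the QFunction externally, so it can be shared between the
    // learning method and the policy that reads from it.
    auto qfun = AIToolbox::MDP::makeQFunction(S, A);

    // Greedy policy over the shared QFunction, wrapped in epsilon-greedy
    // exploration.
    AIToolbox::MDP::QGreedyPolicy greedy(qfun);
    AIToolbox::MDP::EpsilonPolicy policy(greedy, 0.1);

    // ExpectedSARSA updates qfun in place and uses 'policy' to compute the
    // expected value of the next state.
    AIToolbox::MDP::ExpectedSARSA learner(qfun, policy, 0.9, 0.1);
    (void)learner;
}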

◆ ExpectedSARSA() [2/2]

template<IsGenerativeModel M>
AIToolbox::MDP::ExpectedSARSA::ExpectedSARSA ( QFunction &  qfun,
const PolicyInterface &  policy,
const M &  model,
double  alpha = 0.1 
)

Basic constructor.

Note that, unlike normal SARSA, ExpectedSARSA does not contain its own QFunction. This is because many policies are implemented in terms of a QFunction that is continuously updated by a method (e.g. QGreedyPolicy).

At the same time, ExpectedSARSA needs this policy in order to perform its expected value computation. To avoid a chicken and egg problem, ExpectedSARSA takes a QFunction as a parameter, so the user can create it once and use the same one for both ExpectedSARSA and the policy.

The learning rate must be > 0.0 and <= 1.0, otherwise the constructor will throw an std::invalid_argument.

This constructor copies the discount parameter from the supplied model. It does not keep a reference to the model, so if the discount needs to change later you will have to update it here manually as well.

Parameters
qfun    The QFunction underlying the ExpectedSARSA algorithm.
policy    The policy used to select actions.
model    The MDP model that ExpectedSARSA will use as a base.
alpha    The learning rate of the ExpectedSARSA method.

Member Function Documentation

◆ getA()

size_t AIToolbox::MDP::ExpectedSARSA::getA ( ) const

This function returns the number of actions on which ExpectedSARSA is working.

Returns
The number of actions.

◆ getDiscount()

double AIToolbox::MDP::ExpectedSARSA::getDiscount ( ) const

This function returns the currently set discount parameter.

Returns
The currently set discount parameter.

◆ getLearningRate()

double AIToolbox::MDP::ExpectedSARSA::getLearningRate ( ) const

This function returns the currently set learning rate parameter.

Returns
The currently set learning rate parameter.

◆ getPolicy()

const PolicyInterface& AIToolbox::MDP::ExpectedSARSA::getPolicy ( ) const

This function returns a reference to the policy used by ExpectedSARSA.

Returns
The internal policy reference.

◆ getQFunction()

const QFunction& AIToolbox::MDP::ExpectedSARSA::getQFunction ( ) const

This function returns a reference to the internal QFunction.

The returned reference can be used to build Policies, for example MDP::QGreedyPolicy.

Returns
The internal QFunction.

◆ getS()

size_t AIToolbox::MDP::ExpectedSARSA::getS ( ) const

This function returns the number of states on which ExpectedSARSA is working.

Returns
The number of states.

◆ setDiscount()

void AIToolbox::MDP::ExpectedSARSA::setDiscount ( double  d)

This function sets the new discount parameter.

The discount parameter controls how much future rewards are taken into account by ExpectedSARSA. If it is 1, a reward is worth the same whether it is obtained now or in a million timesteps, so the algorithm will optimize the overall accrued reward. When it is less than 1, rewards obtained in the present are valued more than future rewards.

Parameters
d    The new discount factor.

◆ setLearningRate()

void AIToolbox::MDP::ExpectedSARSA::setLearningRate ( double  a)

This function sets the learning rate parameter.

The learning rate parameter determines how quickly the QFunction is modified in response to new data. In fully deterministic environments (such as an agent moving through a grid, for example), this parameter can safely be set to 1.0 for maximum learning speed.

On the other hand, in stochastic environments this parameter should start relatively high and decrease slowly over time in order for the method to converge.

Alternatively, it can be kept somewhat high if the environment dynamics change progressively, so that the algorithm keeps adapting. The final behaviour of ExpectedSARSA is very dependent on this parameter.

The learning rate parameter must be > 0.0 and <= 1.0, otherwise the function will throw an std::invalid_argument.

Parameters
a    The new learning rate parameter.
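
As a concrete sketch of the decreasing schedule described above (the helper and its constants are illustrative, not part of the library):

#include <algorithm>

// Decay the learning rate towards a small floor as episodes accumulate.
// 'Learner' can be any class exposing setLearningRate, e.g. ExpectedSARSA;
// the schedule constants are purely illustrative.
template <typename Learner>
void decayLearningRate(Learner & learner, unsigned episode) {
    learner.setLearningRate(std::max(0.05, 1.0 / (1.0 + 0.01 * episode)));
}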

◆ stepUpdateQ()

void AIToolbox::MDP::ExpectedSARSA::stepUpdateQ ( size_t  s,
size_t  a,
size_t  s1,
double  rew 
)

This function updates the internal QFunction using the discount set during construction.

This function takes a single experience point and uses it to update the QFunction. This is a very efficient method to keep the QFunction up to date with the latest experience.

Keep in mind that, since ExpectedSARSA computes the expected value of the new state under the currently used policy, the policy passed at construction must be the one actually used to select actions, so that the expectation correctly reflects how the agent behaves from state to state.

Parameters
s    The previous state.
a    The action performed.
s1    The new state.
rew    The reward obtained.
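
A typical usage pattern feeds each transition to stepUpdateQ as it is experienced. In the sketch below 'Env' stands for a hypothetical generative environment (not part of AIToolbox) providing reset() and step(s, a); the PolicyInterface include path is an assumption:

#include <AIToolbox/MDP/Algorithms/ExpectedSARSA.hpp>
#include <AIToolbox/MDP/Policies/PolicyInterface.hpp>  // assumed include path
#include <cstddef>

// 'Env' is a hypothetical environment: reset() returns the initial state,
// step(s, a) returns the new state, the reward, and a termination flag.
template <typename Env>
void runEpisode(Env & env,
                const AIToolbox::MDP::PolicyInterface & policy,
                AIToolbox::MDP::ExpectedSARSA & learner,
                unsigned maxSteps) {
    std::size_t s = env.reset();
    for (unsigned t = 0; t < maxSteps; ++t) {
        // Sample from the same policy ExpectedSARSA was constructed with, so
        // the expectation it computes matches the behaviour policy.
        const std::size_t a = policy.sampleAction(s);
        const auto [s1, rew, done] = env.step(s, a);

        learner.stepUpdateQ(s, a, s1, rew);

        if (done) break;
        s = s1;
    }
}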

The documentation for this class was generated from the following file:
AIToolbox/MDP/Algorithms/ExpectedSARSA.hpp