AIToolbox
A library that offers tools for AI problem solving.
SARSA.hpp
Go to the documentation of this file.
1 #ifndef AI_TOOLBOX_MDP_SARSA_HEADER_FILE
2 #define AI_TOOLBOX_MDP_SARSA_HEADER_FILE
3 
4 #include <stddef.h>
5 
9 
10 namespace AIToolbox::MDP {
48  class SARSA {
49  public:
61  SARSA(size_t S, size_t A, double discount = 1.0, double alpha = 0.1);
62 
77  template <IsGenerativeModel M>
78  SARSA(const M& model, double alpha = 0.1);
79 
103  void setLearningRate(double a);
104 
110  double getLearningRate() const;
111 
122  void setDiscount(double d);
123 
129  double getDiscount() const;
130 
149  void stepUpdateQ(size_t s, size_t a, size_t s1, size_t a1, double rew);
150 
156  size_t getS() const;
157 
163  size_t getA() const;
164 
173  const QFunction & getQFunction() const;
174 
175  private:
176  size_t S, A;
177  double alpha_;
178  double discount_;
179 
180  QFunction q_;
181  };
182 
183  template <IsGenerativeModel M>
184  SARSA::SARSA(const M& model, const double alpha) :
185  SARSA(model.getS(), model.getA(), model.getDiscount(), alpha) {}
186 }
187 #endif
AIToolbox::MDP::SARSA::setDiscount
void setDiscount(double d)
This function sets the new discount parameter.
AIToolbox::MDP::SARSA::getDiscount
double getDiscount() const
This function returns the currently set discount parameter.
AIToolbox::MDP::SARSA::getA
size_t getA() const
This function returns the number of actions on which SARSA is working.
AIToolbox::MDP::QFunction
Matrix2D QFunction
Definition: Types.hpp:52
AIToolbox::MDP::SARSA
This class represents the SARSA algorithm.
Definition: SARSA.hpp:48
AIToolbox::MDP::SARSA::setLearningRate
void setLearningRate(double a)
This function sets the learning rate parameter.
AIToolbox::MDP
Definition: DoubleQLearning.hpp:10
Utils.hpp
AIToolbox::MDP::SARSA::getQFunction
const QFunction & getQFunction() const
This function returns a reference to the internal QFunction.
AIToolbox::MDP::SARSA::stepUpdateQ
void stepUpdateQ(size_t s, size_t a, size_t s1, size_t a1, double rew)
This function updates the internal QFunction using the currently set discount (initially the one given at construction, changeable via setDiscount()).
AIToolbox::MDP::SARSA::getLearningRate
double getLearningRate() const
This function returns the currently set learning rate parameter.
AIToolbox::MDP::SARSA::SARSA
SARSA(size_t S, size_t A, double discount=1.0, double alpha=0.1)
Basic constructor.
AIToolbox::MDP::SARSA::getS
size_t getS() const
This function returns the number of states on which SARSA is working.
Types.hpp
TypeTraits.hpp