AIToolbox
A library that offers tools for AI problem solving.
SARSAL.hpp
Go to the documentation of this file.
1 #ifndef AI_TOOLBOX_MDP_SARSAL_HEADER_FILE
2 #define AI_TOOLBOX_MDP_SARSAL_HEADER_FILE
3 
4 #include <stddef.h>
5 
9 
10 namespace AIToolbox::MDP {
// Represents the SARSAL algorithm: holds a QFunction, a list of traces and
// the learning parameters (learning rate, discount, lambda, trace cutoff).
37  class SARSAL {
38  public:
     // A single trace entry: two size_t indices and a double.
     // NOTE(review): presumably (state, action, eligibility value) -- confirm against the implementation.
39  using Trace = std::tuple<size_t, size_t, double>;
     // The collection of traces kept by this class.
40  using Traces = std::vector<Trace>;
     // Basic constructor. Takes the sizes S and A plus the discount, learning
     // rate (alpha), lambda and trace cutoff (tolerance), all with defaults
     // except S and A.
54  SARSAL(size_t S, size_t A, double discount = 1.0, double alpha = 0.1, double lambda = 0.9, double tolerance = 0.001);
55 
     // Constructor from a model satisfying IsGenerativeModel. Its out-of-class
     // definition forwards model.getS(), model.getA() and model.getDiscount()
     // to the basic constructor together with alpha, lambda and tolerance.
72  template <IsGenerativeModel M>
73  SARSAL(const M& model, double alpha = 0.1, double lambda = 0.9, double tolerance = 0.001);
74 
     // Updates the internal QFunction from one step of experience.
     // NOTE(review): presumably s/a are the current state/action, s1/a1 the
     // next state/action and rew the obtained reward -- confirm.
93  void stepUpdateQ(size_t s, size_t a, size_t s1, size_t a1, double rew);
94 
     // Sets the learning rate parameter.
118  void setLearningRate(double a);
119 
     // Returns the currently set learning rate parameter.
125  double getLearningRate() const;
126 
     // Sets the new discount parameter.
137  void setDiscount(double d);
138 
     // Returns the currently set discount parameter.
144  double getDiscount() const;
145 
     // Sets the new lambda parameter.
162  void setLambda(double l);
163 
     // Returns the currently set lambda parameter.
169  double getLambda() const;
170 
     // Sets the trace cutoff parameter.
183  void setTolerance(double t);
184 
     // Returns the currently set trace cutoff parameter.
190  double getTolerance() const;
191 
     // Clears the already set traces.
195  void clearTraces();
196 
     // Returns the currently set traces.
202  const Traces & getTraces() const;
203 
     // Sets the traces from the provided collection.
213  void setTraces(const Traces & t);
214 
     // Returns the number of states.
220  size_t getS() const;
221 
     // Returns the number of actions.
227  size_t getA() const;
228 
     // Returns a reference to the internal QFunction.
237  const QFunction & getQFunction() const;
238 
     // Directly sets the internal QFunction.
250  void setQFunction(const QFunction & qfun);
251 
252  private:
     // NOTE(review): presumably the sizes reported by getS() / getA() -- confirm.
253  size_t S, A;
     // Learning rate, discount, lambda and trace cutoff parameters.
254  double alpha_;
255  double discount_;
256  double lambda_, tolerance_;
257  // This is used to avoid multiplying the discount and lambda all the time.
258  double gammaL_;
259 
     // The internal QFunction and the current list of traces.
260  QFunction q_;
261  Traces traces_;
262  };
263 
// Generative-model constructor: delegates to the basic constructor, passing
// model.getS(), model.getA() and model.getDiscount() followed by the supplied
// alpha, lambda and tolerance. The body itself is empty.
264  template <IsGenerativeModel M>
265  SARSAL::SARSAL(const M& model, const double alpha, const double lambda, const double tolerance) :
266  SARSAL(model.getS(), model.getA(), model.getDiscount(), alpha, lambda, tolerance) {}
267 }
268 #endif
AIToolbox::MDP::SARSAL::Traces
std::vector< Trace > Traces
Definition: SARSAL.hpp:40
AIToolbox::MDP::SARSAL::setTolerance
void setTolerance(double t)
This function sets the trace cutoff parameter.
AIToolbox::MDP::SARSAL::getLambda
double getLambda() const
This function returns the currently set lambda parameter.
AIToolbox::MDP::SARSAL::clearTraces
void clearTraces()
This function clears the already set traces.
AIToolbox::MDP::QFunction
Matrix2D QFunction
Definition: Types.hpp:52
AIToolbox::MDP::SARSAL::getTolerance
double getTolerance() const
This function returns the currently set trace cutoff parameter.
AIToolbox::MDP::SARSAL::getQFunction
const QFunction & getQFunction() const
This function returns a reference to the internal QFunction.
AIToolbox::MDP::SARSAL::stepUpdateQ
void stepUpdateQ(size_t s, size_t a, size_t s1, size_t a1, double rew)
This function updates the internal QFunction using the discount set during construction.
AIToolbox::MDP::SARSAL::getTraces
const Traces & getTraces() const
This function returns the currently set traces.
AIToolbox::MDP
Definition: DoubleQLearning.hpp:10
AIToolbox::MDP::SARSAL::setQFunction
void setQFunction(const QFunction &qfun)
This function allows directly setting the internal QFunction.
AIToolbox::MDP::SARSAL::getLearningRate
double getLearningRate() const
This function will return the currently set learning rate parameter.
Utils.hpp
AIToolbox::MDP::SARSAL::Trace
std::tuple< size_t, size_t, double > Trace
Definition: SARSAL.hpp:39
AIToolbox::MDP::SARSAL::getA
size_t getA() const
This function returns the number of actions on which SARSAL is working.
AIToolbox::MDP::SARSAL::setTraces
void setTraces(const Traces &t)
This function sets the traces, replacing the currently set ones.
AIToolbox::MDP::SARSAL::getS
size_t getS() const
This function returns the number of states on which SARSAL is working.
AIToolbox::MDP::SARSAL::setDiscount
void setDiscount(double d)
This function sets the new discount parameter.
AIToolbox::MDP::SARSAL
This class represents the SARSAL algorithm.
Definition: SARSAL.hpp:37
Types.hpp
AIToolbox::MDP::SARSAL::setLambda
void setLambda(double l)
This function sets the new lambda parameter.
TypeTraits.hpp
AIToolbox::MDP::SARSAL::SARSAL
SARSAL(size_t S, size_t A, double discount=1.0, double alpha=0.1, double lambda=0.9, double tolerance=0.001)
Basic constructor.
AIToolbox::MDP::SARSAL::getDiscount
double getDiscount() const
This function returns the currently set discount parameter.
AIToolbox::MDP::SARSAL::setLearningRate
void setLearningRate(double a)
This function sets the learning rate parameter.