AIToolbox
A library that offers tools for AI problem solving.
Dyna2.hpp
Go to the documentation of this file.
1 #ifndef AI_TOOLBOX_MDP_DYNA2_HEADER_FILE
2 #define AI_TOOLBOX_MDP_DYNA2_HEADER_FILE
3 
9 
10 namespace AIToolbox::MDP {
29  template <IsGenerativeModel M>
30  class Dyna2 {
31  public:
41  explicit Dyna2(const M & m, double alpha = 0.1, double lambda = 0.9, double tolerance = 0.001, unsigned n = 50);
42 
63  void stepUpdateQ(size_t s, size_t a, size_t s1, size_t a1, double rew);
64 
77  void batchUpdateQ(size_t s);
78 
83 
98 
112  void setPermanentLambda(double l);
113 
119  double getPermanentLambda() const;
120 
134  void setTransientLambda(double l);
135 
141  double getTransientLambda() const;
142 
148  void setN(unsigned n);
149 
155  unsigned getN() const;
156 
170  void setTolerance(double t);
171 
177  double getTolerance() const;
178 
184  const QFunction & getPermanentQFunction() const;
185 
191  const QFunction & getTransientQFunction() const;
192 
198  const M & getModel() const;
199 
200  private:
201  unsigned N;
202  const M & model_;
203  SARSAL permanentLearning_;
204  SARSAL transientLearning_;
205  std::unique_ptr<PolicyInterface> internalPolicy_;
206  };
207 
208  template <IsGenerativeModel M>
209  Dyna2<M>::Dyna2(const M & m, const double alpha, const double lambda, const double tolerance, const unsigned n) :
210  N(n), model_(m),
211  permanentLearning_(model_, alpha, lambda, tolerance),
212  transientLearning_(model_, alpha, lambda, tolerance),
213  internalPolicy_(new BanditPolicyAdaptor<Bandit::RandomPolicy>(model_.getS(), model_.getA()))
214  {
215  }
216 
217  template <IsGenerativeModel M>
218  void Dyna2<M>::stepUpdateQ(const size_t s, const size_t a, const size_t s1, const size_t a1, const double rew) {
219  // We copy the traces from the permanent SARSAL to the transient one so
220  // that they will update their respective QFunctions in (nearly) the
221  // same way.
222  //
223  // Note that this is not quite the same as it is stated in the paper.
224  // Normally one would update only permanentLearning_, and transfer the
225  // exact same changes directly to the QFunction of transientLearning_.
226  //
227  // They differ since the QFunction inside each method are different,
228  // and so the updates won't exactly match. At the same time, after each
229  // reset (or end of episodes) the transient memory should reset to the
230  // permanent one, so this minor differences should go away.
231  //
232  // Ideally one would update directly the two QFunctions here, but this
233  // would basically require re-implementing SARSAL both here and in the
234  // batchUpdateQ method, which we avoid here for practicality.
235  transientLearning_.setTraces(permanentLearning_.getTraces());
236  permanentLearning_.stepUpdateQ(s, a, s1, a1, rew);
237  transientLearning_.stepUpdateQ(s, a, s1, a1, rew);
238  }
239 
240  template <IsGenerativeModel M>
241  void Dyna2<M>::batchUpdateQ(const size_t initS) {
242  // This clearing may not be needed if this is called after stepUpdateQ
243  // with the same s1 (since the set traces there will be correct then).
244  // We do it anyway in case this method is called in different settings
245  // and/or multiple times in a row.
246  transientLearning_.clearTraces();
247 
248  size_t s = initS;
249  size_t a = internalPolicy_->sampleAction(s);
250  for ( unsigned i = 0; i < N; ++i ) {
251  const auto [s1, rew] = model_.sampleSR(s, a);
252  const size_t a1 = internalPolicy_->sampleAction(s1);
253 
254  transientLearning_.stepUpdateQ(s, a, s1, a1, rew);
255 
256  if (model_.isTerminal(s1)) {
257  s = initS;
258  a = internalPolicy_->sampleAction(s);
259  } else {
260  s = s1;
261  a = a1;
262  }
263  }
264  }
265 
266  template <IsGenerativeModel M>
268  transientLearning_.setQFunction(permanentLearning_.getQFunction());
269  }
270  template <IsGenerativeModel M>
272  internalPolicy_.reset(p);
273  }
274 
275  template <IsGenerativeModel M>
276  unsigned Dyna2<M>::getN() const {
277  return N;
278  }
279 
280  template <IsGenerativeModel M>
281  void Dyna2<M>::setTolerance(const double t) {
282  transientLearning_.setTolerance(t);
283  permanentLearning_.setTolerance(t);
284  }
285 
286  template <IsGenerativeModel M>
287  double Dyna2<M>::getTolerance() const {
288  return permanentLearning_.getTolerance();
289  }
290 
291  template <IsGenerativeModel M>
293  return permanentLearning_.getQFunction();
294  }
295 
296  template <IsGenerativeModel M>
298  return transientLearning_.getQFunction();
299  }
300 
301  template <IsGenerativeModel M>
302  const M & Dyna2<M>::getModel() const {
303  return model_;
304  }
305 
306  template <IsGenerativeModel M>
307  void Dyna2<M>::setPermanentLambda(double l) { permanentLearning_.setLambda(l); }
308  template <IsGenerativeModel M>
309  double Dyna2<M>::getPermanentLambda() const { return permanentLearning_.getLambda(); }
310  template <IsGenerativeModel M>
311  void Dyna2<M>::setTransientLambda(double l) { transientLearning_.setLambda(l); }
312  template <IsGenerativeModel M>
313  double Dyna2<M>::getTransientLambda() const { return transientLearning_.getLambda(); }
314 }
315 
316 #endif
AIToolbox::MDP::Dyna2::getN
unsigned getN() const
This function returns the currently set number of sampling passes during batchUpdateQ().
Definition: Dyna2.hpp:276
AIToolbox::MDP::Dyna2::getPermanentQFunction
const QFunction & getPermanentQFunction() const
This function returns a reference to the internal permanent QFunction.
Definition: Dyna2.hpp:292
AIToolbox::MDP::Dyna2::resetTransientLearning
void resetTransientLearning()
This function resets the transient QFunction to the permanent one.
Definition: Dyna2.hpp:267
AIToolbox::MDP::Dyna2::getTolerance
double getTolerance() const
This function returns the currently set trace cutoff parameter.
Definition: Dyna2.hpp:287
AIToolbox::MDP::QFunction
Matrix2D QFunction
Definition: Types.hpp:52
RandomPolicy.hpp
AIToolbox::MDP::Dyna2::stepUpdateQ
void stepUpdateQ(size_t s, size_t a, size_t s1, size_t a1, double rew)
This function updates the internal QFunction.
Definition: Dyna2.hpp:218
AIToolbox::MDP::Dyna2::getTransientQFunction
const QFunction & getTransientQFunction() const
This function returns a reference to the internal transient QFunction.
Definition: Dyna2.hpp:297
AIToolbox::MDP::BanditPolicyAdaptor
This class extends a Bandit policy so that it can be called from MDP code.
Definition: BanditPolicyAdaptor.hpp:17
AIToolbox::MDP::Dyna2::setN
void setN(unsigned n)
This function sets the current sample number parameter.
AIToolbox::MDP::Dyna2::getPermanentLambda
double getPermanentLambda() const
This function returns the currently set lambda parameter for the permanent SARSAL.
Definition: Dyna2.hpp:309
BanditPolicyAdaptor.hpp
AIToolbox::MDP
Definition: DoubleQLearning.hpp:10
AIToolbox::MDP::Dyna2
This class represents the Dyna2 algorithm.
Definition: Dyna2.hpp:30
AIToolbox::MDP::Dyna2::setPermanentLambda
void setPermanentLambda(double l)
This function sets the new lambda parameter for the permanent SARSAL.
Definition: Dyna2.hpp:307
AIToolbox::MDP::Dyna2::setTransientLambda
void setTransientLambda(double l)
This function sets the new lambda parameter for the transient SARSAL.
Definition: Dyna2.hpp:311
AIToolbox::MDP::Dyna2::getTransientLambda
double getTransientLambda() const
This function returns the currently set lambda parameter for the transient SARSAL.
Definition: Dyna2.hpp:313
AIToolbox::MDP::Dyna2::setTolerance
void setTolerance(double t)
This function sets the trace cutoff parameter.
Definition: Dyna2.hpp:281
AIToolbox::MDP::Dyna2::setInternalPolicy
void setInternalPolicy(PolicyInterface *p)
This function sets the policy used to sample during batch updates.
Definition: Dyna2.hpp:271
AIToolbox::MDP::SARSAL
This class represents the SARSAL algorithm.
Definition: SARSAL.hpp:37
AIToolbox::MDP::Dyna2::getModel
const M & getModel() const
This function returns a reference to the referenced Model.
Definition: Dyna2.hpp:302
Types.hpp
TypeTraits.hpp
SARSAL.hpp
AIToolbox::MDP::PolicyInterface
Simple typedef for most of MDP's policy needs.
Definition: PolicyInterface.hpp:11
AIToolbox::MDP::Dyna2::Dyna2
Dyna2(const M &m, double alpha=0.1, double lambda=0.9, double tolerance=0.001, unsigned n=50)
Basic constructor.
Definition: Dyna2.hpp:209
AIToolbox::MDP::Dyna2::batchUpdateQ
void batchUpdateQ(size_t s)
This function updates a QFunction based on simulated experience.
Definition: Dyna2.hpp:241