AIToolbox
A library that offers tools for AI problem solving.
|
Go to the documentation of this file. 1 #ifndef AI_TOOLBOX_MDP_OFF_POLICY_TEMPLATE_HEADER_FILE
2 #define AI_TOOLBOX_MDP_OFF_POLICY_TEMPLATE_HEADER_FILE
/**
 * @brief Element type for the trace bookkeeping of the off-policy learners.
 *
 * NOTE(review): presumably (state, action, trace value). The two size_t
 * fields match the (s, a) index pair used throughout this file, but the
 * meaning of the double is not visible here; confirm it against the
 * updateTraces body.
 */
13 using Trace = std::tuple<size_t, size_t, double>;
/**
 * @brief Constructs the state shared by the off-policy learners.
 *
 * @param s Presumably the size of the state space (the evaluation
 *          constructor forwards target.getS() here).
 * @param a Presumably the size of the action space (the evaluation
 *          constructor forwards target.getA() here).
 * @param discount Multiplies the bootstrapped expected Q in the TD target
 *                 and the per-step trace discount (default 1.0).
 * @param alpha Learning rate. The TD error is scaled by it before
 *              updateTraces is called (default 0.1).
 * @param tolerance Default 0.001. NOTE(review): its use is not visible in
 *                  this listing; presumably a cutoff for negligible traces.
 *                  Confirm.
 */
25 OffPolicyBase(
size_t s,
size_t a,
double discount = 1.0,
double alpha = 0.1,
double tolerance = 0.001);
/**
 * @brief Applies one TD update for the (s, a) transition using the stored traces.
 *
 * Both stepUpdateQ implementations in this file compute their arguments the
 * same way. `error` is alpha_ * (rew + discount_ * expectedQ - q_(s, a)),
 * so it is already scaled by the learning rate. `traceDiscount` is discount_
 * times the derived class's getTraceDiscount(...) result.
 *
 * @param s The state of the current transition.
 * @param a The action taken in s.
 * @param error The learning-rate-scaled TD error.
 * @param traceDiscount Per-step trace factor.
 *
 * NOTE(review): the body is not visible here. How traces are decayed,
 * incremented and pruned is to be confirmed.
 */
169 void updateTraces(
size_t s,
size_t a,
double error,
double traceDiscount);
214 template <
typename Derived>
228 double alpha = 0.1,
double tolerance = 0.001);
/**
 * @brief Performs one off-policy evaluation step from a (s, a, s1, rew) sample.
 *
 * The out-of-line definition shown later in this listing does four things:
 * 1. Computes expectedQ = sum over actions of q_(s1, a) times
 *    target_.getActionProbability(s1, a).
 * 2. Computes error = alpha_ * (rew + discount_ * expectedQ - q_(s, a)).
 * 3. Computes traceDiscount = discount_ times the derived class's
 *    getTraceDiscount(s, a, s1, rew).
 * 4. Passes both to updateTraces(s, a, error, traceDiscount).
 *
 * @param s The state the action was taken in.
 * @param a The action taken.
 * @param s1 The resulting state.
 * @param rew The reward obtained.
 */
242 void stepUpdateQ(
const size_t s,
const size_t a,
const size_t s1,
const double rew);
293 template <
typename Derived>
/**
 * @brief Constructs an off-policy control learner.
 *
 * s, a, discount, alpha and tolerance are forwarded unchanged to the Parent
 * (base) constructor.
 *
 * @param s Presumably the size of the state space.
 * @param a Presumably the size of the action space.
 * @param discount Discount applied to the bootstrapped expected Q and to
 *                 traces (default 1.0).
 * @param alpha Learning rate scaling the TD error (default 0.1).
 * @param tolerance Default 0.001. Its use is not visible in this listing.
 * @param epsilon Exploration weight of the epsilon-greedy target
 *                expectation used in stepUpdateQ (default 0.1). The epsilon
 *                check shown later in this listing throws
 *                std::invalid_argument outside [0, 1]. NOTE(review): whether
 *                this constructor routes epsilon through that check is not
 *                visible. Confirm.
 */
308 OffPolicyControl(
size_t s,
size_t a,
double discount = 1.0,
double alpha = 0.1,
309 double tolerance = 0.001,
double epsilon = 0.1);
/**
 * @brief Performs one off-policy control step from a (s, a, s1, rew) sample.
 *
 * The out-of-line definition shown later in this listing works as follows:
 * 1. It bootstraps from the expected Q of an epsilon-greedy policy at s1:
 *    expectedQ = epsilon_ / A * sum over actions of q_(s1, aa),
 *    plus (1 - epsilon_) * max over actions of q_(s1, aa).
 * 2. It computes error = alpha_ * (rew + discount_ * expectedQ - q_(s, a)).
 * 3. It computes traceDiscount = discount_ times the derived class's
 *    getTraceDiscount(s, a, s1, rew, maxA).
 * 4. It applies both via updateTraces.
 *
 * NOTE(review): maxA is presumably the greedy action at s1, but the lines
 * computing it are not visible here.
 *
 * @param s The state the action was taken in.
 * @param a The action taken.
 * @param s1 The resulting state.
 * @param rew The reward obtained.
 */
323 void stepUpdateQ(
const size_t s,
const size_t a,
const size_t s1,
const double rew);
352 template <
typename Derived>
354 auto expectedQ = 0.0;
355 for (
size_t a = 0; a < A; ++a)
356 expectedQ += q_(s1, a) * target_.getActionProbability(s1, a);
358 const auto error = alpha_ * ( rew + discount_ * expectedQ - q_(s, a) );
359 const auto traceDiscount = discount_ *
static_cast<Derived*
>(
this)->getTraceDiscount(s, a, s1, rew);
361 updateTraces(s, a, error, traceDiscount);
364 template <
typename Derived>
374 double expectedQ = 0.0;
375 double maxV = std::numeric_limits<double>::lowest();
376 for (
size_t aa = 0; aa < A; ++aa) {
377 expectedQ += q_(s1, aa);
378 if (maxV < q_(s1, aa)) {
383 expectedQ *= epsilon_ / A;
384 expectedQ += (1.0 - epsilon_) * maxV;
386 const auto error = alpha_ * ( rew + discount_ * expectedQ - q_(s, a) );
387 const auto traceDiscount = discount_ *
static_cast<Derived*
>(
this)->getTraceDiscount(s, a, s1, rew, maxA);
389 updateTraces(s, a, error, traceDiscount);
392 template <
typename Derived>
395 const double discount,
const double alpha,
const double tolerance
397 Parent(target.getS(), target.getA(), discount, alpha, tolerance),
400 template <
typename Derived>
402 const size_t s,
const size_t a,
const double discount,
403 const double alpha,
const double tolerance,
const double epsilon
405 Parent(s, a, discount, alpha, tolerance)
410 template <
typename Derived>
412 if ( e < 0.0 || e > 1.0 )
throw std::invalid_argument(
"Epsilon must be >= 0 and <= 1");
416 template <
typename Derived>