AIToolbox
A library that offers tools for AI problem solving.
Rollout.hpp
Go to the documentation of this file.
1 #ifndef AI_TOOLBOX_MDP_ROLLOUT_HEADER_FILE
2 #define AI_TOOLBOX_MDP_ROLLOUT_HEADER_FILE
3 
6 
7 namespace AIToolbox::MDP {
30  template <typename M, typename Gen>
31  requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
32  double rollout(const M & m, std::remove_cvref_t<decltype(std::declval<M>().getS())> s, const unsigned maxDepth, Gen & rnd) {
33  double rew = 0.0, totalRew = 0.0, gamma = 1.0;
34 
35  // Here we have two separate branches depending on whether the model
36  // provides a variable number of actions or not. Note that they are
37  // nearly identical, and the only difference is how we instantiate the
38  // distribution from which to randomly sample actions.
39 
40  if constexpr (HasFixedActionSpace<M>) {
41  // If we don't have variable actions, we instantiate the
42  // distribution outside of the loop for a slight performance
43  // increase.
44  std::uniform_int_distribution<size_t> dist(0, m.getA()-1);
45  for (unsigned depth = 0; depth < maxDepth; ++depth ) {
46 
47  std::tie( s, rew ) = m.sampleSR( s, dist(rnd) );
48  totalRew += gamma * rew;
49 
50  if (m.isTerminal(s))
51  return totalRew;
52 
53  gamma *= m.getDiscount();
54  }
55  } else {
56  // Otherwise, we need to poll the model at every timestep to check
57  // the allowed number of actions, and sample from those.
58  for (unsigned depth = 0; depth < maxDepth; ++depth ) {
59  std::uniform_int_distribution<size_t> dist(0, m.getA(s)-1);
60 
61  std::tie( s, rew ) = m.sampleSR( s, dist(rnd) );
62  totalRew += gamma * rew;
63 
64  if (m.isTerminal(s))
65  return totalRew;
66 
67  gamma *= m.getDiscount();
68  }
69  }
70  return totalRew;
71  }
72 }
73 
74 #endif
AIToolbox::MDP
Definition: DoubleQLearning.hpp:10
AIToolbox::MDP::rollout
requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M> double rollout(const M & m, std::remove_cvref_t<decltype(std::declval<M>().getS())> s, const unsigned maxDepth, Gen & rnd)
This function performs a rollout from the input state.
Definition: Rollout.hpp:32
TypeTraits.hpp
Types.hpp