AIToolbox
A library that offers tools for AI problem solving.
MCTS.hpp
Go to the documentation of this file.
1 #ifndef AI_TOOLBOX_MDP_MCTS_HEADER_FILE
2 #define AI_TOOLBOX_MDP_MCTS_HEADER_FILE
3 
7 #include <AIToolbox/Seeder.hpp>
9 
10 #include <unordered_map>
11 
12 namespace AIToolbox::MDP {
47  template <typename M, template <typename> class StateHash = std::hash>
48  requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
49  class MCTS {
50  using State = std::remove_cvref_t<decltype(std::declval<M>().getS())>;
51  static constexpr bool hashState = !std::is_same_v<size_t, State>;
52 
53  public:
54  struct StateNode;
55  using StateNodes = std::unordered_map<size_t, StateNode>;
56 
57  struct ActionNode {
59  double V = 0.0;
60  unsigned N = 0;
61  };
62  using ActionNodes = std::vector<ActionNode>;
63 
64  struct StateNode {
66  unsigned N = 0;
67  };
68 
76  MCTS(const M& m, unsigned iterations, double exp);
77 
86  size_t sampleAction(const State & s, unsigned horizon);
87 
106  size_t sampleAction(size_t a, const State & s1, unsigned horizon);
107 
113  void setIterations(unsigned iter);
114 
126  void setExploration(double exp);
127 
133  const M& getModel() const;
134 
140  const StateNode& getGraph() const;
141 
147  unsigned getIterations() const;
148 
154  double getExploration() const;
155 
156  private:
157  const M& model_;
158  unsigned iterations_, maxDepth_;
159  double exploration_;
160 
161  StateNode graph_;
162 
163  mutable RandomEngine rand_;
164 
165  // Private Methods
166  size_t runSimulation(const State & s, unsigned horizon);
167  double simulate(StateNode & sn, const State & s, unsigned horizon);
168  void allocateActionNodes(ActionNodes & an, const State & s);
169 
170  template <typename Iterator>
171  Iterator findBestA(Iterator begin, Iterator end);
172 
173  template <typename Iterator>
174  Iterator findBestBonusA(Iterator begin, Iterator end, unsigned count);
175  };
176 
177  template <typename M, template <typename> class StateHash>
178  requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
179  MCTS<M, StateHash>::MCTS(const M& m, const unsigned iter, const double exp) :
180  model_(m), iterations_(iter),
181  exploration_(exp), graph_(), rand_(Seeder::getSeed()) {}
182 
183  template <typename M, template <typename> class StateHash>
184  requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
185  size_t MCTS<M, StateHash>::sampleAction(const State & s, const unsigned horizon) {
186  // Reset graph
187  graph_ = StateNode();
188 
189  allocateActionNodes(graph_.children, s);
190 
191  return runSimulation(s, horizon);
192  }
193 
194  template <typename M, template <typename> class StateHash>
195  requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
196  size_t MCTS<M, StateHash>::sampleAction(const size_t a, const State & s1, const unsigned horizon) {
197  auto & states = graph_.children[a].children;
198 
199  size_t s1Key;
200  if constexpr (hashState) s1Key = StateHash<State>()(s1);
201  else s1Key = s1;
202 
203  auto it = states.find(s1Key);
204  if ( it == states.end() )
205  return sampleAction(s1, horizon);
206 
207  // Here we need an additional step, because *it is contained by graph_.
208  // If we just move assign, graph_ is first going to delete everything it
209  // contains (included *it), and then we are going to move unallocated memory
210  // into graph_! So we move *it outside of the graph_ hierarchy, so that
211  // we can then assign safely.
212  { auto tmp = std::move(it->second); graph_ = std::move(tmp); }
213 
214  // We resize here in case we didn't have time to sample the new
215  // head node. In this case, the new head may not have children.
216  // This would break the UCT call.
217  allocateActionNodes(graph_.children, s1);
218 
219  return runSimulation(s1, horizon);
220  }
221 
222  template <typename M, template <typename> class StateHash>
223  requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
224  size_t MCTS<M, StateHash>::runSimulation(const State & s, const unsigned horizon) {
225  if ( !horizon ) return 0;
226 
227  maxDepth_ = horizon;
228 
229  for (unsigned i = 0; i < iterations_; ++i )
230  simulate(graph_, s, 0);
231 
232  auto begin = std::begin(graph_.children);
233  return std::distance(begin, findBestA(begin, std::end(graph_.children)));
234  }
235 
236  template <typename M, template <typename> class StateHash>
237  requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
238  double MCTS<M, StateHash>::simulate(StateNode & sn, const State & s, const unsigned depth) {
239  // Head update
240  sn.N++;
241 
242  auto begin = std::begin(sn.children);
243  const size_t a = std::distance(begin, findBestBonusA(begin, std::end(sn.children), sn.N));
244 
245  auto [s1, rew] = model_.sampleSR(s, a);
246 
247  auto & aNode = sn.children[a];
248 
249  // We only go deeper if needed (maxDepth_ is always at least 1).
250  if ( depth + 1 < maxDepth_ && !model_.isTerminal(s1) ) {
251  // If our state is not a size_t, hash it so we can work with the
252  // StateNode map. The reason to hash it ourselves is that the map
253  // *will* store the keys, and so if the state is an expensive
254  // object (like a vector), we will have tons of allocations which
255  // we can avoid, since we don't need to remember the exact state here.
256  //
257  // This *could* go wrong if two reachable states hash to the same
258  // thing, since in this way we won't be able to distinguish them
259  // (while a full-fledged map can), but this should be extremely
260  // improbable and worth the performance gain.
261  size_t s1Key;
262  if constexpr (hashState) s1Key = StateHash<State>()(s1);
263  else s1Key = s1;
264 
265  auto it = aNode.children.find(s1Key);
266 
267  double futureRew;
268  if ( it == std::end(aNode.children) ) {
269  // Touch node to create it
270  aNode.children[s1Key];
271  futureRew = rollout(model_, s1, maxDepth_ - depth + 1, rand_);
272  }
273  else {
274  // Since most memory is allocated on the leaves,
275  // we do not allocate on node creation but only when
276  // we are actually descending into a node. If the node
277  // already has memory this should not do anything in
278  // any case.
279  allocateActionNodes(it->second.children, s1);
280  futureRew = simulate( it->second, s1, depth + 1 );
281  }
282 
283  rew += model_.getDiscount() * futureRew;
284  }
285 
286  // Action update
287  aNode.N++;
288  aNode.V += ( rew - aNode.V ) / static_cast<double>(aNode.N);
289 
290  return rew;
291  }
292 
293  template <typename M, template <typename> class StateHash>
294  requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
295  template <typename Iterator>
296  Iterator MCTS<M, StateHash>::findBestA(Iterator begin, Iterator end) {
297  return std::max_element(begin, end, [](const ActionNode & lhs, const ActionNode & rhs){ return lhs.V < rhs.V; });
298  }
299 
300  template <typename M, template <typename> class StateHash>
301  requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
302  template <typename Iterator>
303  Iterator MCTS<M, StateHash>::findBestBonusA(Iterator begin, Iterator end, const unsigned count) {
304  // Count here can be as low as 1.
305  // Since log(1) = 0, and 0/0 = error, we add 1.0.
306  const double logCount = std::log(count + 1.0);
307  // We use this function to produce a score for each action. This can be easily
308  // substituted with something else to produce different POMCP variants.
309  const auto evaluationFunction = [this, logCount](const ActionNode & an){
310  return an.V + exploration_ * std::sqrt( logCount / an.N );
311  };
312 
313  auto bestIterator = begin++;
314  double bestValue = evaluationFunction(*bestIterator);
315 
316  for ( ; begin < end; ++begin ) {
317  double actionValue = evaluationFunction(*begin);
318  if ( actionValue > bestValue ) {
319  bestValue = actionValue;
320  bestIterator = begin;
321  }
322  }
323 
324  return bestIterator;
325  }
326 
327  template <typename M, template <typename> class StateHash>
328  requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
330  if constexpr (HasFixedActionSpace<M>)
331  an.resize(model_.getA());
332  else
333  an.resize(model_.getA(s));
334  }
335 
336  template <typename M, template <typename> class StateHash>
337  requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
338  void MCTS<M, StateHash>::setIterations(const unsigned iter) {
339  iterations_ = iter;
340  }
341 
342  template <typename M, template <typename> class StateHash>
343  requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
344  void MCTS<M, StateHash>::setExploration(const double exp) {
345  exploration_ = exp;
346  }
347 
348  template <typename M, template <typename> class StateHash>
349  requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
350  const M& MCTS<M, StateHash>::getModel() const {
351  return model_;
352  }
353 
354  template <typename M, template <typename> class StateHash>
355  requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
357  return graph_;
358  }
359 
360  template <typename M, template <typename> class StateHash>
361  requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
363  return iterations_;
364  }
365 
366  template <typename M, template <typename> class StateHash>
367  requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
369  return exploration_;
370  }
371 }
372 
373 #endif
AIToolbox::Seeder
This class is an internal class used to seed all random engines in the library.
Definition: Seeder.hpp:15
AIToolbox::POMDP::TigerProblemUtils::State
State
Definition: TigerProblem.hpp:15
AIToolbox::MDP::MCTS::StateNode::N
unsigned N
Definition: MCTS.hpp:66
AIToolbox::MDP::MCTS
This class represents the MCTS online planner using UCB1.
Definition: MCTS.hpp:49
AIToolbox::MDP::MCTS::MCTS
MCTS(const M &m, unsigned iterations, double exp)
Basic constructor.
Definition: MCTS.hpp:179
AIToolbox::MDP::MCTS::getIterations
unsigned getIterations() const
This function returns the number of iterations performed to plan for an action.
Definition: MCTS.hpp:362
Rollout.hpp
AIToolbox::MDP::MCTS::ActionNode
Definition: MCTS.hpp:57
AIToolbox::MDP
Definition: DoubleQLearning.hpp:10
AIToolbox::MDP::rollout
requires AIToolbox::IsGenerativeModel< M > &&HasIntegralActionSpace< M > double rollout(const M &m, std::remove_cvref_t< decltype(std::declval< M >().getS())> s, const unsigned maxDepth, Gen &rnd)
This function performs a rollout from the input state.
Definition: Rollout.hpp:32
AIToolbox::MDP::MCTS::ActionNodes
std::vector< ActionNode > ActionNodes
Definition: MCTS.hpp:62
AIToolbox::MDP::MCTS::getModel
const M & getModel() const
This function returns the MDP generative model being used.
Definition: MCTS.hpp:350
Seeder.hpp
AIToolbox::RandomEngine
std::mt19937 RandomEngine
Definition: Types.hpp:14
AIToolbox::MDP::MCTS::StateNode::children
ActionNodes children
Definition: MCTS.hpp:65
AIToolbox::MDP::MCTS::setIterations
void setIterations(unsigned iter)
This function sets the number of performed rollouts in MCTS.
Definition: MCTS.hpp:338
AIToolbox::MDP::MCTS::StateNode
Definition: MCTS.hpp:64
AIToolbox::MDP::MCTS::getExploration
double getExploration() const
This function returns the currently set exploration constant.
Definition: MCTS.hpp:368
AIToolbox::MDP::MCTS::ActionNode::children
StateNodes children
Definition: MCTS.hpp:58
AIToolbox::MDP::MCTS::getGraph
const StateNode & getGraph() const
This function returns a reference to the internal graph structure holding the results of rollouts.
Definition: MCTS.hpp:356
AIToolbox::POMDP::ActionNodes
std::vector< ActionNode< UseEntropy > > ActionNodes
Definition: rPOMCPGraph.hpp:27
AIToolbox::MDP::MCTS::sampleAction
size_t sampleAction(const State &s, unsigned horizon)
This function resets the internal graph and samples for the provided state and horizon.
Definition: MCTS.hpp:185
Types.hpp
AIToolbox::MDP::MCTS::setExploration
void setExploration(double exp)
This function sets the new exploration constant for MCTS.
Definition: MCTS.hpp:344
TypeTraits.hpp
AIToolbox::MDP::MCTS::StateNodes
std::unordered_map< size_t, StateNode > StateNodes
Definition: MCTS.hpp:55
AIToolbox::MDP::MCTS::ActionNode::N
unsigned N
Definition: MCTS.hpp:60
AIToolbox::MDP::MCTS::ActionNode::V
double V
Definition: MCTS.hpp:59
Probability.hpp