1 #ifndef AI_TOOLBOX_MDP_MCTS_HEADER_FILE
2 #define AI_TOOLBOX_MDP_MCTS_HEADER_FILE
10 #include <unordered_map>
// Template head, presumably of the MCTS class (the class line itself is not
// visible in this excerpt). M is the model type; StateHash is a unary class
// template, defaulting to std::hash, that is instantiated as StateHash<State>
// to turn a state into a size_t key.
47 template <
typename M,
template <
typename>
class StateHash = std::hash>
// M must satisfy both the IsGenerativeModel and HasIntegralActionSpace constraints.
48 requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
// State is whatever M::getS() returns, with cv-qualifiers and references removed.
50 using State = std::remove_cvref_t<decltype(std::declval<M>().getS())>;
// True when State is not already size_t; the member definitions hash the
// state with StateHash<State> only in that case.
51 static constexpr
bool hashState = !std::is_same_v<size_t, State>;
// Map from a size_t state key to a StateNode (presumably the type of an
// action node's children; the node structs are not visible here).
55 using StateNodes = std::unordered_map<size_t, StateNode>;
// Constructor declaration.
//   m          - model, taken by const reference.
//   iterations - unsigned iteration count.
//   exp        - double; the out-of-line definition stores it in exploration_.
76 MCTS(
const M& m,
unsigned iterations,
double exp);
// Three-argument sampleAction overload.
//   a       - index into graph_.children selecting the action node to follow.
//   s1      - state looked up among that action node's children.
//   horizon - forwarded to runSimulation (or to the two-argument overload
//             when s1 is not found).
// Returns a size_t, which the definition obtains from runSimulation.
106 size_t sampleAction(
size_t a,
const State & s1,
unsigned horizon);
// iterations_ is the loop bound on simulate() calls in runSimulation;
// maxDepth_ is compared against depth + 1 in simulate() to decide whether
// to look past the sampled next state.
158 unsigned iterations_, maxDepth_;
// Runs the simulations from state s and returns an index into
// graph_.children (the position chosen by findBestA). The visible part of
// the definition returns 0 immediately when horizon is 0.
166 size_t runSimulation(
const State & s,
unsigned horizon);
// One simulation step starting at node sn in state s; recurses into child
// state nodes. Note that the third parameter is named horizon in this
// declaration but depth in the out-of-line definition, where it is compared
// against maxDepth_ and incremented on recursion.
// Presumably returns the sampled return; the return statement is not
// visible in this excerpt.
167 double simulate(
StateNode & sn,
const State & s,
unsigned horizon);
// Resizes an to the model's action count: model_.getA() when
// HasFixedActionSpace<M> holds, model_.getA(s) in the other visible resize
// call (presumably the else branch).
168 void allocateActionNodes(
ActionNodes & an,
const State & s);
// Returns an iterator into [begin, end); the definition uses
// std::max_element with a comparison on the V member of ActionNode.
170 template <
typename Iterator>
171 Iterator findBestA(Iterator begin, Iterator end);
// Scans [begin, end) for the ActionNode with the largest value of
// V + exploration_ * sqrt(log(count + 1) / N). Presumably returns the
// iterator to that node; the return statement is not visible here.
173 template <
typename Iterator>
174 Iterator findBestBonusA(Iterator begin, Iterator end,
unsigned count);
// Out-of-line constructor definition; the signature line is not visible in
// this excerpt, only the template head and the member initializer list.
177 template <
typename M,
template <
typename>
class StateHash>
178 requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
// NOTE(review): the initializer reads a parameter named iter while the
// in-class declaration names it iterations; harmless if the unseen
// signature line declares iter, but confirm.
// NOTE(review): maxDepth_ does not appear in the visible initializer list;
// confirm it is assigned before simulate() reads it.
180 model_(m), iterations_(iter),
// graph_ is value-initialized and rand_ is seeded from Seeder::getSeed().
181 exploration_(exp), graph_(), rand_(
Seeder::getSeed()) {}
// Presumably the two-argument sampleAction(s, horizon) overload; the
// signature line is not visible in this excerpt.
183 template <
typename M,
template <
typename>
class StateHash>
184 requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
// Size the root's action nodes for state s before simulating from it.
189 allocateActionNodes(graph_.children, s);
// The chosen action index comes straight from runSimulation.
191 return runSimulation(s, horizon);
// Presumably the three-argument sampleAction(a, s1, horizon) overload; the
// signature line is not visible in this excerpt. It tries to continue from
// the subtree under action a that matches state s1.
194 template <
typename M,
template <
typename>
class StateHash>
195 requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
// State nodes recorded under action a of the current root.
197 auto & states = graph_.children[a].children;
// When State is not size_t, the lookup key is the hash of s1 (the other
// branch is not visible here).
200 if constexpr (hashState) s1Key = StateHash<State>()(s1);
203 auto it = states.find(s1Key);
// s1 has no node under this action: fall back to the two-argument overload.
204 if ( it == states.end() )
205 return sampleAction(s1, horizon);
// The found node lives inside graph_ itself (states refers to
// graph_.children[a].children), so it is first moved into a temporary and
// only then assigned to graph_, presumably to avoid assigning graph_ from
// one of its own descendants.
212 {
auto tmp = std::move(it->second); graph_ = std::move(tmp); }
// Size the new root's action nodes for s1 before simulating.
217 allocateActionNodes(graph_.children, s1);
219 return runSimulation(s1, horizon);
// Presumably the runSimulation(s, horizon) definition; the signature line
// and some statements between the visible lines are missing from this
// excerpt.
222 template <
typename M,
template <
typename>
class StateHash>
223 requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
// A zero horizon short-circuits to index 0 without running any simulation.
225 if ( !horizon )
return 0;
// Run iterations_ simulations, each starting at the root node at depth 0;
// the value returned by simulate() is discarded here.
229 for (
unsigned i = 0; i < iterations_; ++i )
230 simulate(graph_, s, 0);
// Convert the iterator picked by findBestA into an index into graph_.children.
232 auto begin = std::begin(graph_.children);
233 return std::distance(begin, findBestA(begin, std::end(graph_.children)));
// Definition of simulate(): one simulation step from node sn in state s at
// the given depth. Several original lines (declarations, else branches,
// counter updates, the return) are not visible in this excerpt.
236 template <
typename M,
template <
typename>
class StateHash>
237 requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
238 double MCTS<M, StateHash>::simulate(StateNode & sn,
const State & s,
const unsigned depth) {
// Pick the action index via findBestBonusA, passing sn.N as the count.
242 auto begin = std::begin(sn.children);
243 const size_t a = std::distance(begin, findBestBonusA(begin, std::end(sn.children), sn.N));
// Sample a next state and a reward from the model for (s, a).
245 auto [s1, rew] = model_.sampleSR(s, a);
247 auto & aNode = sn.children[a];
// Only look past s1 while below maxDepth_ and s1 is not terminal.
250 if ( depth + 1 < maxDepth_ && !model_.isTerminal(s1) ) {
// When State is not size_t, the child key is the hash of s1 (the other
// branch is not visible here).
262 if constexpr (hashState) s1Key = StateHash<State>()(s1);
265 auto it = aNode.children.find(s1Key);
// First visit of s1 under this action: create its node with operator[] and
// estimate the future reward with rollout() instead of recursing.
268 if ( it == std::end(aNode.children) ) {
270 aNode.children[s1Key];
// NOTE(review): the third argument is maxDepth_ - depth + 1, which is
// larger than maxDepth_ - (depth + 1); confirm this is the intended value
// for rollout().
271 futureRew =
rollout(model_, s1, maxDepth_ - depth + 1, rand_);
// Known s1 (presumably the else branch): make sure its action nodes are
// sized, then recurse one level deeper.
279 allocateActionNodes(it->second.children, s1);
280 futureRew = simulate( it->second, s1, depth + 1 );
// Add the future estimate, scaled by the model's discount, to the reward.
283 rew += model_.getDiscount() * futureRew;
// Moves aNode.V toward rew by 1/aNode.N of the difference; assumes aNode.N
// was incremented just before this line (not visible) - TODO confirm.
288 aNode.V += ( rew - aNode.V ) /
static_cast<double>(aNode.N);
// Presumably the findBestA definition; the signature line is not visible in
// this excerpt. Outer template head is for the class, inner one for the
// Iterator parameter.
293 template <
typename M,
template <
typename>
class StateHash>
294 requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
295 template <
typename Iterator>
// Returns the iterator to the ActionNode with the greatest V; for an empty
// range std::max_element returns end.
297 return std::max_element(begin, end, [](
const ActionNode & lhs,
const ActionNode & rhs){
return lhs.
V < rhs.
V; });
// Presumably the findBestBonusA definition; the signature line, closing
// braces and return statement are not visible in this excerpt.
300 template <
typename M,
template <
typename>
class StateHash>
301 requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
302 template <
typename Iterator>
// Computed once, outside the per-action lambda; the + 1.0 also converts
// count to double.
306 const double logCount = std::log(count + 1.0);
// Score of an action: its V plus exploration_ times sqrt(logCount / N).
// an.N == 0 makes the division by zero; presumably N is a floating-point
// field so the score becomes infinite - TODO confirm the type of N.
309 const auto evaluationFunction = [
this, logCount](
const ActionNode & an){
310 return an.V + exploration_ * std::sqrt( logCount / an.N );
// Seed the search with the first element, then advance begin past it.
// NOTE(review): this dereferences begin unconditionally, so the range is
// assumed non-empty - confirm callers guarantee that.
313 auto bestIterator = begin++;
314 double bestValue = evaluationFunction(*bestIterator);
// NOTE(review): begin < end requires iterators that support operator<
// (random access), not just forward iterators.
316 for ( ; begin < end; ++begin ) {
317 double actionValue = evaluationFunction(*begin);
// Strict > keeps the earliest element on ties.
318 if ( actionValue > bestValue ) {
319 bestValue = actionValue;
320 bestIterator = begin;
// Presumably the allocateActionNodes(an, s) definition; the signature line
// and the else keyword are not visible in this excerpt.
327 template <
typename M,
template <
typename>
class StateHash>
328 requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
// Compile-time choice of how the action count is queried: without arguments
// when the model satisfies HasFixedActionSpace, ...
330 if constexpr (HasFixedActionSpace<M>)
331 an.resize(model_.getA());
// ... and per state s in the other resize call (presumably the else branch).
333 an.resize(model_.getA(s));
336 template <
typename M,
template <
typename>
class StateHash>
337 requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
342 template <
typename M,
template <
typename>
class StateHash>
343 requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
348 template <
typename M,
template <
typename>
class StateHash>
349 requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
354 template <
typename M,
template <
typename>
class StateHash>
355 requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
360 template <
typename M,
template <
typename>
class StateHash>
361 requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>
366 template <
typename M,
template <
typename>
class StateHash>
367 requires AIToolbox::IsGenerativeModel<M> && HasIntegralActionSpace<M>