1 #ifndef AI_TOOLBOX_POMDP_POMCP_HEADER_FILE
2 #define AI_TOOLBOX_POMDP_POMCP_HEADER_FILE
4 #include <unordered_map>
61 template <IsGenerativeModel M>
92 POMCP(
const M& m,
size_t beliefSize,
unsigned iterations,
double exp);
131 size_t sampleAction(
size_t a,
size_t o,
unsigned horizon);
201 size_t S, A, beliefSize_;
202 unsigned iterations_, maxDepth_;
223 size_t runSimulation(
unsigned horizon);
247 double simulate(
BeliefNode & b,
size_t s,
unsigned horizon);
258 template <
typename Iterator>
259 Iterator findBestA(Iterator begin, Iterator end);
275 template <
typename Iterator>
276 Iterator findBestBonusA(Iterator begin, Iterator end,
unsigned count);
288 template <IsGenerativeModel M>
289 POMCP<M>::POMCP(
const M& m,
const size_t beliefSize,
const unsigned iter,
const double exp) :
290 model_(m), S(model_.getS()), A(model_.getA()), beliefSize_(beliefSize),
291 iterations_(iter), exploration_(exp), graph_(), rand_(
Seeder::getSeed()) {}
293 template <IsGenerativeModel M>
297 graph_.children.resize(A);
298 graph_.belief = makeSampledBelief(b);
300 return runSimulation(horizon);
303 template <IsGenerativeModel M>
305 const auto & obs = graph_.children[a].children;
307 auto it = obs.find(o);
308 if ( it == obs.end() ) {
310 auto b =
Belief(S); b.fill(1.0/S);
311 return sampleAction(b, horizon);
319 {
auto tmp = std::move(it->second); graph_ = std::move(tmp); }
321 if ( ! graph_.belief.size() ) {
323 auto b =
Belief(S); b.fill(1.0/S);
324 return sampleAction(b, horizon);
330 graph_.children.resize(A);
332 return runSimulation(horizon);
335 template <IsGenerativeModel M>
337 if ( !horizon )
return 0;
340 std::uniform_int_distribution<size_t> generator(0, graph_.belief.size()-1);
342 for (
unsigned i = 0; i < iterations_; ++i )
343 simulate(graph_, graph_.belief.at(generator(rand_)), 0);
345 auto begin = std::begin(graph_.children);
346 return std::distance(begin, findBestA(begin, std::end(graph_.children)));
349 template <IsGenerativeModel M>
350 double POMCP<M>::simulate(BeliefNode & b,
const size_t s,
const unsigned depth) {
353 auto begin = std::begin(b.children);
354 const size_t a = std::distance(begin, findBestBonusA(begin, std::end(b.children), b.N));
356 auto [s1, o, rew] = model_.sampleSOR(s, a);
358 auto & aNode = b.children[a];
361 double futureRew = 0.0;
364 auto ot = aNode.children.find(o);
365 if ( ot == std::end(aNode.children) ) {
366 aNode.children.emplace(std::piecewise_construct,
367 std::forward_as_tuple(o),
368 std::forward_as_tuple(s1));
370 futureRew =
rollout(model_, s1, maxDepth_ - depth + 1, rand_);
373 ot->second.belief.push_back(s1);
375 if ( depth + 1 < maxDepth_ && !model_.isTerminal(s1) ) {
381 ot->second.children.resize(A);
382 futureRew = simulate( ot->second, s1, depth + 1 );
386 rew += model_.getDiscount() * futureRew;
391 aNode.V += ( rew - aNode.V ) /
static_cast<double>(aNode.N);
396 template <IsGenerativeModel M>
397 template <
typename Iterator>
398 Iterator POMCP<M>::findBestA(Iterator begin, Iterator end) {
399 return std::max_element(begin, end, [](
const ActionNode & lhs,
const ActionNode & rhs){
return lhs.V < rhs.V; });
402 template <IsGenerativeModel M>
403 template <
typename Iterator>
404 Iterator POMCP<M>::findBestBonusA(Iterator begin, Iterator end,
const unsigned count) {
407 const double logCount = std::log(count + 1.0);
410 auto evaluationFunction = [
this, logCount](
const ActionNode & an){
411 return an.V + exploration_ * std::sqrt( logCount / an.N );
414 auto bestIterator = begin++;
415 double bestValue = evaluationFunction(*bestIterator);
417 for ( ; begin < end; ++begin ) {
418 const double actionValue = evaluationFunction(*begin);
419 if ( actionValue > bestValue ) {
420 bestValue = actionValue;
421 bestIterator = begin;
428 template <IsGenerativeModel M>
431 belief.reserve(beliefSize_);
433 for (
size_t i = 0; i < beliefSize_; ++i )
439 template <IsGenerativeModel M>
441 beliefSize_ = beliefSize;
444 template <IsGenerativeModel M>
449 template <IsGenerativeModel M>
454 template <IsGenerativeModel M>
459 template <IsGenerativeModel M>
464 template <IsGenerativeModel M>
469 template <IsGenerativeModel M>
474 template <IsGenerativeModel M>