1 #ifndef AI_TOOLBOX_POMDP_BELIEF_GENERATOR_HEADER_FILE 
    2 #define AI_TOOLBOX_POMDP_BELIEF_GENERATOR_HEADER_FILE 
    6 #include <boost/container/flat_set.hpp> 
   17     template <IsGenerativeModel M>
 
   // SeenObservations: one entry per belief in the list being expanded (the
   // container is resized to beliefs.size() and indexed by belief position).
   // Each entry is a flat_set of {a, o} pairs, where `a` is the action index
   // being tried and `o` is the value sampled from the model for it; find()
   // on the set is used to detect a pair already produced from that belief.
   61             using SeenObservations = std::vector<boost::container::flat_set<std::pair<size_t, size_t>>>;
 
   // Runs one expansion pass over the belief list. The method is const; the
   // state it updates is reached through the class's mutable members.
   //  - max: target list size; the final insertion loop breaks as soon as
   //    goodBeliefsSize_ >= max.
   //  - randomBeliefsToAdd: iteration count of the first candidate loop, which
   //    runs before any existing belief is sampled from.
   //  - firstProductiveBelief: index where the scan over existing beliefs
   //    starts; beliefs below it are not visited.
   // NOTE(review): only fragments of the definition are visible in this view;
   // confirm these details against the complete body.
   72             void expandBeliefList(
size_t max, 
size_t randomBeliefsToAdd, 
size_t firstProductiveBelief) 
const;
 
   // Upper bound of the inner sampling loop that is run for every action of
   // every belief visited during an expansion pass.
   80             static constexpr 
unsigned triesPerRun_ = 20;
 
   // Threshold on a belief's "nothing new found" counter: beliefs whose counter
   // has reached this value are skipped by later passes, and productiveBeliefs_
   // is decremented at the moment a counter becomes equal to it.
   81             static constexpr 
unsigned retryLimit_ = 5;
 
   // Lower bound on productiveBeliefs_: when the count is below this value the
   // shortfall is passed to the next pass as randomBeliefsToAdd (0 otherwise).
   82             static constexpr 
unsigned minProductiveBeliefs_ = 10;
 
   // Non-owning pointer to the per-belief seen-pair sets. The generation
   // routine stores the address of one of its local variables here, so the
   // pointer is only meaningful while that call is in progress.
   84             mutable SeenObservations * sop_;
 
   // Non-owning pointer to the per-belief counters compared against
   // retryLimit_; also set to the address of a local of the generation routine.
   85             mutable std::vector<unsigned> * up_; 
 
   // Non-owning pointer to the vector of candidate distances dereferenced by
   // expandBeliefList. NOTE(review): its assignment is not visible in this
   // view; presumably it is set alongside sop_ and up_ — confirm.
   86             mutable std::vector<double> * dp_; 
 
   // Bookkeeping sizes, all initialised to beliefs.size() before the main loop:
   //  - goodBeliefsSize_: the main loop runs while this is below the requested
   //    maximum, and it is the upper index of the scan over existing beliefs.
   //  - allBeliefsSize_: (allBeliefsSize_ - goodBeliefsSize_) is asserted to
   //    equal the number of stored candidate distances.
   //  - productiveBeliefs_: decremented when a belief's counter hits
   //    retryLimit_, incremented for each candidate accepted into the list.
   87             mutable size_t goodBeliefsSize_, allBeliefsSize_, productiveBeliefs_;
 
   92     template <IsGenerativeModel M>
 
   94             model_(model), S(model_.getS()), A(model_.getA()),
 
   96             blp_(nullptr), sop_(nullptr), up_(nullptr), dp_(nullptr), helper_(S) {}
 
   98     template <IsGenerativeModel M>
 
  101         BeliefList beliefs; beliefs.reserve(std::max(beliefNumber, S));
 
  103         beliefs.emplace_back(S);
 
  104         beliefs.back().fill(1.0/S);
 
  106         for ( 
size_t s = 0; s < S && s < beliefNumber; ++s ) {
 
  107             beliefs.emplace_back(S);
 
  108             beliefs.back().setZero(); beliefs.back()(s) = 1.0;
 
  111         this->operator()(beliefNumber, &beliefs);
 
  116     template <IsGenerativeModel M>
 
  141         auto & beliefs = *blp_;
 
  143         SeenObservations seenObservations;
 
  144         seenObservations.resize(beliefs.size());
 
  146         std::vector<unsigned> unproductiveBeliefs;
 
  147         unproductiveBeliefs.resize(beliefs.size());
 
  149         sop_ = &seenObservations;
 
  150         up_ = &unproductiveBeliefs;
 
  152         beliefs.reserve(maxBeliefs);
 
  153         seenObservations.reserve(maxBeliefs);
 
  154         unproductiveBeliefs.reserve(maxBeliefs);
 
  156         std::vector<double> distances;
 
  164         size_t firstProductiveBelief = 0;
 
  165         productiveBeliefs_ = goodBeliefsSize_ = allBeliefsSize_ = beliefs.size();
 
  167         unsigned randomBeliefsToAdd = 0;
 
  169         while ( goodBeliefsSize_ < maxBeliefs ) {
 
  170             expandBeliefList(maxBeliefs, randomBeliefsToAdd, firstProductiveBelief);
 
  171             if (goodBeliefsSize_ >= maxBeliefs) 
break;
 
  175             for (
size_t i = firstProductiveBelief; i < goodBeliefsSize_; ++i) {
 
  176                 if (unproductiveBeliefs[i] < retryLimit_) 
break;
 
  177                 else ++firstProductiveBelief;
 
  181             randomBeliefsToAdd = productiveBeliefs_ >= minProductiveBeliefs_ ? 0 : minProductiveBeliefs_ - productiveBeliefs_;
 
  184         beliefs.resize(maxBeliefs);
 
  187     template <IsGenerativeModel M>
 
  190         auto & seenObservations = *sop_;
 
  191         auto & unproductiveBeliefs = *up_;
 
  192         auto & distances = *dp_;
 
  199         auto beliefsToAdd = std::max(randomBeliefsToAdd, productiveBeliefs_);
 
  202         auto computeDistance = [](
const Belief & lhs, 
const Belief & rhs) {
 
  203             return (lhs - rhs).cwiseAbs().sum();
 
  208         if (allBeliefsSize_ < bl.size())
 
  212         for ( 
size_t i = 0; i < randomBeliefsToAdd; ++i) {
 
  217             distances.push_back(std::numeric_limits<double>::max());
 
  218             for (
size_t k = 0; k < goodBeliefsSize_; ++k) {
 
  219                 distances.back() = std::min(distances.back(), computeDistance(bl.back(), bl[k]));
 
  226         for (
size_t i = firstProductiveBelief; i < goodBeliefsSize_; ++i) {
 
  228             auto & notFoundCounter = unproductiveBeliefs[i];
 
  229             if (notFoundCounter >= retryLimit_) 
continue;
 
  231             auto & beliefObservations = seenObservations[i];
 
  232             bool foundAnything = 
false;
 
  235             for ( 
size_t a = 0; a < A; ++a ) {
 
  238                 for (
unsigned j = 0; j < triesPerRun_; ++j) {
 
  244                     std::tie(std::ignore, o, std::ignore) = model_.sampleSOR(s, a);
 
  249                     if (beliefObservations.find({a,o}) != beliefObservations.end())
 
  252                     beliefObservations.insert({a,o});
 
  253                     foundAnything = 
true;
 
  258                     if (allBeliefsSize_ == bl.size())
 
  266                     for (
size_t k = 0; k < allBeliefsSize_; ++k) {
 
  281                     distances.push_back(std::numeric_limits<double>::max());
 
  282                     for (
size_t k = 0; k < goodBeliefsSize_; ++k) {
 
  283                         distances.back() = std::min(distances.back(), computeDistance(bl.back(), bl[k]));
 
  289             if (!foundAnything) {
 
  292                 if (notFoundCounter == retryLimit_)
 
  293                     --productiveBeliefs_;
 
  300         beliefsToAdd = std::min(beliefsToAdd, allBeliefsSize_ - goodBeliefsSize_);
 
  302         for (
size_t i = 0; i < beliefsToAdd; ++i) {
 
  303             assert((allBeliefsSize_ - goodBeliefsSize_) == distances.size());
 
  310             auto dBegin = std::begin(distances), dEnd = std::end(distances);
 
  311             size_t id = std::distance( dBegin, std::max_element(dBegin, dEnd) );
 
  314             std::swap(distances[
id], distances.back());
 
  315             std::swap(bl[goodBeliefsSize_ + 
id], bl[allBeliefsSize_ - 1]);
 
  320             std::swap(bl[goodBeliefsSize_], bl[allBeliefsSize_ - 1]);
 
  324             if (goodBeliefsSize_ >= max) 
break;
 
  334             distances.pop_back();
 
  335             seenObservations.emplace_back();
 
  336             unproductiveBeliefs.emplace_back();
 
  337             ++productiveBeliefs_;
 
  339             for (
size_t k = 0; k < distances.size(); ++k) {
 
  340                 distances[k] = std::min(distances[k], computeDistance(bl[goodBeliefsSize_ - 1], bl[goodBeliefsSize_ + k]));