1 #ifndef AI_TOOLBOX_POMDP_MODEL_HEADER_FILE
2 #define AI_TOOLBOX_POMDP_MODEL_HEADER_FILE
14 template <MDP::IsModel M>
18 template <MDP::IsModel M>
67 template <MDP::IsModel M>
68 class Model :
public M {
82 template <
typename... Args>
83 Model(
size_t o, Args&&... parameters);
114 Model(
size_t o, ObFun && of, Args&&... parameters);
131 template <
typename PM>
132 requires IsModel<PM> && std::constructible_from<M, PM>
133 Model(
const PM& model);
150 template <
typename... Args>
170 template <IsNaive3DMatrix ObFun>
210 std::tuple<size_t,size_t, double>
sampleSOR(
size_t s,
size_t a)
const;
228 std::tuple<size_t, double>
sampleOR(
size_t s,
size_t a,
size_t s1)
const;
272 template <MDP::IsModel M>
273 template <
typename... Args>
275 M(std::forward<Args>(params)...), O(o),
276 observations_(this->getA(),
Matrix2D(this->getS(), O)), rand_(
Seeder::getSeed())
278 for (
size_t a = 0; a < this->getA(); ++a ) {
279 observations_[a].rightCols(O-1).setZero();
280 observations_[a].col(0).fill(1.0);
284 template <MDP::IsModel M>
287 M(std::forward<Args>(params)...), O(o),
288 observations_(this->getA(),
Matrix2D(this->getS(), O)), rand_(
Seeder::getSeed())
293 template <MDP::IsModel M>
294 template <
typename... Args>
296 M(std::forward<Args>(params)...), O(o),
297 observations_(std::move(ot))
300 template <MDP::IsModel M>
301 template <
typename PM>
302 requires IsModel<PM> && std::constructible_from<M, PM>
304 M(model), O(model.getO()), observations_(this->getA(),
Matrix2D(this->getS(), O)),
307 for (
size_t a = 0; a < this->getA(); ++a )
308 for (
size_t s1 = 0; s1 < this->getS(); ++s1 ) {
309 for (
size_t o = 0; o < O; ++o ) {
310 observations_[a](s1, o) = model.getObservationProbability(s1, a, o);
313 throw std::invalid_argument(
"Input observation matrix does not contain valid probabilities.");
317 template <MDP::IsModel M>
318 template <IsNaive3DMatrix ObFun>
320 for (
size_t s1 = 0; s1 < this->getS(); ++s1 )
321 for (
size_t a = 0; a < this->getA(); ++a )
323 throw std::invalid_argument(
"Input observation matrix does not contain valid probabilities.");
325 for (
size_t s1 = 0; s1 < this->getS(); ++s1 )
326 for (
size_t a = 0; a < this->getA(); ++a )
327 for (
size_t o = 0; o < O; ++o )
328 observations_[a](s1, o) = of[s1][a][o];
331 template <MDP::IsModel M>
334 throw std::invalid_argument(
"Input observation matrix does not contain valid probabilities.");
339 template <MDP::IsModel M>
341 return observations_[a](s1, o);
344 template <MDP::IsModel M>
346 return observations_[a];
349 template <MDP::IsModel M>
354 template <MDP::IsModel M>
356 return observations_;
359 template <MDP::IsModel M>
361 const auto [s1, r] = this->sampleSR(s, a);
363 return std::make_tuple(s1, o, r);
366 template <MDP::IsModel M>
367 std::tuple<size_t, double>
Model<M>::sampleOR(
const size_t s,
const size_t a,
const size_t s1)
const {
369 const double r = this->getExpectedReward(s, a, s1);
370 return std::make_tuple(o, r);