AIToolbox
A library that offers tools for AI problem solving.
|
Go to the documentation of this file. 1 #ifndef AI_TOOLBOX_MDP_PRIORITIZED_SWEEPING_HEADER_FILE
2 #define AI_TOOLBOX_MDP_PRIORITIZED_SWEEPING_HEADER_FILE
5 #include <unordered_map>
7 #include <boost/heap/fibonacci_heap.hpp>
// NOTE(review): only the declarations are visible in this listing. Presumably
// these set/read the member N, which the batch loop (`i < N`) uses as the
// maximum number of queue elements processed per call -- confirm against the
// definitions.
100 void setN(
unsigned n);
107 unsigned getN()
const;
// Element stored in the priority queue: a priority value together with the
// pair of indices it refers to. It is built as {delta, pair} where the pair
// comes from std::make_pair(ss, a), i.e. (state index, action index).
// NOTE(review): the declaration of the `priority` member and the closing
// braces are not visible in this listing.
157 struct PriorityQueueElement {
// After an element is read from the top of the queue, this pair is passed on
// as stepUpdateQ(pair.first, pair.second).
159 std::pair<size_t, size_t> stateAction;
// Orders elements by priority only; stateAction takes no part in the comparison.
160 bool operator<(
const PriorityQueueElement& arg2)
const {
161 return priority < arg2.priority;
// Heap type that holds the pending updates. No comparator is passed, so the
// ordering comes from PriorityQueueElement::operator<.
165 using QueueType = boost::heap::fibonacci_heap<PriorityQueueElement>;
// Maps a (state, action) pair to the handle of its element inside the heap.
// The handle lets an already-queued pair be found and have its priority
// raised in place (see the find() / increase() calls in the update code).
169 std::unordered_map<std::pair<size_t, size_t>,
typename QueueType::handle_type, boost::hash<std::pair<size_t, size_t>>> queueHandles_;
174 S(m.getS()), A(m.getA()), N(n), theta_(theta), model_(m),
// NOTE(review): this is the body of the per-(s, a) update; its signature, the
// `else` lines and the closing braces are not visible in this listing.
//
// Shorthand for the value vector that is read and overwritten below.
179 auto & values = vfun_.values;
// Recompute Q(s, a). For Eigen-based models this is the stored reward for
// (s, a) plus the dot product of the transition row of (s, a) with the
// discounted values.
182 if constexpr(IsModelEigen<M>) {
183 qfun_(s,a) = model_.getRewardFunction().coeff(s, a) + model_.getTransitionFunction(a).row(s).dot(values * model_.getDiscount());
// Otherwise (presumably the `else` branch -- the line is elided) accumulate
// the same quantity explicitly over every successor state s1:
//   sum over s1 of  P(s, a, s1) * ( R(s, a, s1) + discount * values[s1] )
185 double newQValue = 0;
186 for (
size_t s1 = 0; s1 < S; ++s1 ) {
187 const double probability = model_.getTransitionProbability(s,a,s1);
189 newQValue += probability * ( model_.getExpectedReward(s,a,s1) + model_.getDiscount() * values[s1] );
191 qfun_(s, a) = newQValue;
// Remember the old value of s so the size of the change can be measured.
194 double p = values[s];
// The new value of s is the maximum over row s of qfun_. The pointer argument
// receives the column index of that maximum (assuming qfun_ is an Eigen
// matrix -- TODO confirm), which is stored as the action for s.
197 values[s] = qfun_.row(s).maxCoeff(&(vfun_.actions[s]));
// p now holds the absolute change in the value of s.
200 p = std::fabs(values[s] - p);
// Visit every (ss, a) pair and weight the change by the probability of
// reaching s from ss under a. Note that this loop variable `a` shadows the
// outer `a` used in the Q update above.
202 for (
size_t ss = 0; ss < S; ++ss ) {
203 for (
size_t a = 0; a < A; ++a ) {
204 const double delta = p * model_.getTransitionProbability(ss,a,s);
// Only pairs whose weighted change is strictly above theta_ are queued.
206 if ( delta > theta_ ) {
207 const auto pair = std::make_pair(ss, a);
208 auto it = queueHandles_.find(pair);
// Pair is already queued: raise its priority if delta is larger. A smaller
// delta leaves the existing priority untouched.
210 if (it != std::end(queueHandles_)) {
211 if ((*it->second).priority < delta) {
212 (*it->second).priority = delta;
// Tell the heap the element's priority went up so it can restore its ordering.
213 queue_.increase(it->second);
// NOTE(review): presumably the `else` branch (line elided): the pair is not
// queued yet, so insert it and remember the handle returned by emplace().
216 queueHandles_[pair] = queue_.emplace(PriorityQueueElement{delta, pair});
// Process at most N queued pairs, stopping early as soon as the queue is empty.
225 for (
unsigned i = 0; i < N; ++i ) {
226 if ( queue_.empty() )
return;
// Copy the top element out of the queue (plain `auto`, so p and pair are
// copies rather than references into the heap).
// NOTE(review): the statement that removes the element from queue_ is
// expected around here but is not visible in this listing.
230 auto [p, pair] = queue_.top();
// Drop the stored handle so the pair no longer counts as queued.
234 queueHandles_.erase(pair);
// Recompute the pair's Q-value. This may push further pairs onto the queue.
236 stepUpdateQ(pair.first, pair.second);
// Reject negative thresholds with std::invalid_argument; zero is accepted.
252 if ( t < 0.0 )
throw std::invalid_argument(
"Theta parameter must be >= 0");
// Number of elements currently held in the priority queue.
263 return queue_.size();