AIToolbox
A library that offers tools for AI problem solving.
GenericVariableElimination.hpp
Go to the documentation of this file.
1 #ifndef AI_TOOLBOX_FACTORED_GENERIC_VARIABLE_ELIMINATION_HEADER_FILE
2 #define AI_TOOLBOX_FACTORED_GENERIC_VARIABLE_ELIMINATION_HEADER_FILE
3 
6 
7 #include <AIToolbox/Logging.hpp>
9 
10 namespace AIToolbox::Factored {
71  template <typename Factor>
73  public:
74  using Rule = std::pair<size_t, Factor>;
75  using Rules = std::vector<Rule>;
77  using FinalFactors = std::vector<Factor>;
78 
86  template <typename Global>
87  void operator()(const Factors & V, Graph & graph, Global & global);
88 
89  private:
95  template <typename M>
96  struct global_interface;
97 
107  template <typename Global>
108  void removeFactor(const Factors & V, Graph & graph, const size_t v, FinalFactors & finalFactors, Global & global);
109  };
110 
111  template <typename Factor>
112  template <typename M>
113  struct GenericVariableElimination<Factor>::global_interface {
114  #define STR2(X) #X
115  #define STR(X) STR2(X)
116  #define ARG(...) __VA_ARGS__
117 
118  // For each function we want to check, we are going to try each
119  // overload in succession (char->int->long->...).
120  //
121  // The first two simply accept the function with the approved
122  // signature, whether it is const or not. The third checks whether
123  // the member just exists, and reports that it probably has the
124  // wrong signature (since we didn't match before).
125  //
126  // The last just fails to find the match.
127  #define MEMBER_CHECK(name, retval, input) \
128  \
129  private: \
130  \
131  template <typename Z> static constexpr auto name##Check(char) -> decltype( \
132  static_cast<retval (Z::*)(input)> (&Z::name), \
133  bool() \
134  ) { return true; } \
135  template <typename Z> static constexpr auto name##Check(int) -> decltype( \
136  static_cast<retval (Z::*)(input) const> (&Z::name), \
137  bool() \
138  ) { return true; } \
139  template <typename Z> static constexpr auto name##Check(long) -> decltype( \
140  &Z::name, \
141  bool()) \
142  { \
143  static_assert(Impl::is_compatible_f< \
144  decltype(&Z::name), \
145  retval(input) \
146  >::value, "You provide a member '" STR(name) "' but with the wrong signature."); \
147  return true; \
148  } \
149  template <typename Z> static constexpr auto name##Check(...) -> bool { return false; } \
150  \
151  public: \
152  enum { \
153  name = name##Check<M>('\0') \
154  };
155 
156  MEMBER_CHECK(beginRemoval, void, ARG(const Graph &, const typename Graph::FactorItList &, const typename Graph::Variables &, size_t))
157  MEMBER_CHECK(initNewFactor, void, void)
158  MEMBER_CHECK(beginCrossSum, void, size_t)
159  MEMBER_CHECK(beginFactorCrossSum, void, void)
160  MEMBER_CHECK(crossSum, void, const Factor &)
161  MEMBER_CHECK(endFactorCrossSum, void, void)
162  MEMBER_CHECK(endCrossSum, void, void)
163  MEMBER_CHECK(isValidNewFactor, bool, void)
164  MEMBER_CHECK(mergeFactors, void, ARG(Factor &, Factor &&))
165  MEMBER_CHECK(makeResult, void, FinalFactors &&)
166 
167  #undef MEMBER_CHECK
168  #undef ARG
169  #undef STR
170  #undef STR2
171  };
172 
173  template <typename Factor>
174  template <typename Global>
175  void GenericVariableElimination<Factor>::operator()(const Factors & V, Graph & graph, Global & global) {
176  static_assert(global_interface<Global>::crossSum, "You must provide a crossSum method!");
177  static_assert(global_interface<Global>::makeResult, "You must provide a makeResult method!");
178  static_assert(std::is_same_v<Factor, decltype(global.newFactor)>, "You must provide a public 'Factor newFactor;' member!");
179 
180  FinalFactors finalFactors;
181 
182  // We remove variables one at a time from the graph, storing the last
183  // remaining nodes in the finalFactors variable.
184  while (graph.variableSize())
185  removeFactor(V, graph, graph.bestVariableToRemove(V), finalFactors, global);
186 
187  global.makeResult(std::move(finalFactors));
188  }
189 
190  template <typename Factor>
191  template <typename Global>
192  void GenericVariableElimination<Factor>::removeFactor(const Factors & V, Graph & graph, const size_t v, FinalFactors & finalFactors, Global & global) {
193  AI_LOGGER(AI_SEVERITY_INFO, "Removing variable " << v);
194 
195  // We iterate over all possible joint values of the neighbors of 'f';
196  // these are all variables which share at least one factor with it.
197  const auto & factors = graph.getFactors(v);
198  const auto & vNeighbors = graph.getVariables(v);
199 
200  if constexpr(global_interface<Global>::beginRemoval)
201  Impl::callFunction(global, &Global::beginRemoval, graph, factors, vNeighbors, v);
202 
203  // We'll now create new rules that represent the elimination of the
204  // input variable for this round.
205  const bool isFinalFactor = vNeighbors.size() == 0;
206 
207  Rules * oldRulesP;
208  size_t oldRulesCurrId = 0;
209 
210  PartialFactorsEnumerator jointValues(V, vNeighbors, v, true);
211  const auto id = jointValues.getFactorToSkipId();
212 
213  if (!isFinalFactor) {
214  oldRulesP = &graph.getFactor(vNeighbors)->getData();
215  oldRulesP->reserve(jointValues.size());
216  }
217 
218  AI_LOGGER(
220  "Width of this factor: " << vNeighbors.size() + 1 << ". "
221  "Joint values to iterate: " << jointValues.size() * V[v]
222  );
223 
224  size_t jvID = 0;
225  while (jointValues.isValid()) {
226  auto & jointValue = *jointValues;
227 
228  if constexpr(global_interface<Global>::initNewFactor)
229  global.initNewFactor();
230 
231  // Since we are eliminating 'v', we iterate over its possible
232  // values and we reduce over them; this could be a cross-sum
233  // operation, a max, or anything else.
234  for (size_t vValue = 0; vValue < V[v]; ++vValue) {
235  if constexpr(global_interface<Global>::beginCrossSum)
236  Impl::callFunction(global, &Global::beginCrossSum, vValue);
237 
238  jointValue.second[id] = vValue;
239  for (const auto factor : factors) {
240  if constexpr(global_interface<Global>::beginFactorCrossSum)
241  global.beginFactorCrossSum();
242 
243  // We reduce over each Factor that is applicable to this
244  // particular joint value set.
245  const size_t jvPartialIndex = toIndexPartial(factor->getVariables(), V, jointValue);
246  if constexpr(global_interface<Global>::mergeFactors) {
247  const auto & data = factor->getData();
248  const auto ruleIt = std::lower_bound(
249  std::begin(data),
250  std::end(data),
251  jvPartialIndex,
252  [](const Rule & lhs, const size_t rhs) {
253  return lhs.first < rhs;
254  }
255  );
256  if (ruleIt != std::end(data) && ruleIt->first == jvPartialIndex)
257  global.crossSum(ruleIt->second);
258  } else {
259  for (const auto & rule : factor->getData())
260  if (jvPartialIndex == rule.first)
261  global.crossSum(rule.second);
262  }
263 
264  if constexpr(global_interface<Global>::endFactorCrossSum)
265  global.endFactorCrossSum();
266  }
267 
268  if constexpr(global_interface<Global>::endCrossSum)
269  global.endCrossSum();
270  }
271 
272  bool isValidNewFactor = true;
273  if constexpr(global_interface<Global>::isValidNewFactor)
274  isValidNewFactor = global.isValidNewFactor();
275 
276  // If the new Factor is good, we save it together with the joint
277  // value that has produced it (minus the one of the variable to
278  // remove). If it has no neighbors, we add it to the finalFactors
279  // instead.
280  if (isValidNewFactor) {
281  if (!isFinalFactor) {
282  auto & oldRules = *oldRulesP;
283 
284  // If we care enough to merge, we store all rules in
285  // lexicographical order of value; if the old rules already
286  // contained this same value and we are provided with a
287  // merge function, we can merge the two, otherwise we
288  // insert it as-is in the correct spot.
289  if constexpr(global_interface<Global>::mergeFactors) {
290  while (oldRulesCurrId < oldRules.size() && oldRules[oldRulesCurrId].first < jvID)
291  ++oldRulesCurrId;
292 
293  if (oldRulesCurrId < oldRules.size() && oldRules[oldRulesCurrId].first == jvID) {
294  global.mergeFactors(oldRules[oldRulesCurrId].second, std::move(global.newFactor));
295  } else {
296  oldRules.emplace(std::begin(oldRules) + oldRulesCurrId, jvID, std::move(global.newFactor));
297  }
298  } else {
299  // Otherwise we simply append, as it should be faster.
300  // Remember, a factor may be appended on multiple
301  // times, but it's only iterated over once before being
302  // removed.
303  oldRules.emplace_back(jvID, std::move(global.newFactor));
304  }
305  ++oldRulesCurrId;
306  }
307  else
308  finalFactors.push_back(std::move(global.newFactor));
309  }
310  ++jvID;
311  jointValues.advance();
312  }
313 
314  // And finally we remove the variable from the graph.
315  graph.erase(v);
316  }
317 }
318 
319 #endif
AIToolbox::Factored::GenericVariableElimination::operator()
void operator()(const Factors &V, Graph &graph, Global &global)
This operator performs the Variable Elimination operation on the inputs.
Definition: GenericVariableElimination.hpp:175
AIToolbox::Factored::removeFactor
PartialFactors removeFactor(const PartialFactors &pf, size_t f)
This function removes the specified factor from the input PartialFactors.
MEMBER_CHECK
#define MEMBER_CHECK(name, retval, input)
Definition: GenericVariableElimination.hpp:127
AIToolbox::Factored::FactorGraph::Variables
PartialKeys Variables
Definition: FactorGraph.hpp:33
AIToolbox::Factored::FactorGraph::FactorItList
std::vector< FactorIt > FactorItList
Definition: FactorGraph.hpp:50
AIToolbox::Factored::FactorGraph::variableSize
size_t variableSize() const
This function returns the number of variables still in the graph.
Definition: FactorGraph.hpp:429
AIToolbox::Factored::toIndexPartial
size_t toIndexPartial(const PartialKeys &ids, const Factors &space, const Factors &f)
This function converts the input factor in the input space to a unique index.
AIToolbox::Factored::FactorGraph
This class offers a minimal interface to manage a factor graph.
Definition: FactorGraph.hpp:31
ARG
#define ARG(...)
Definition: GenericVariableElimination.hpp:116
AIToolbox::Factored::GenericVariableElimination
This class represents the Variable Elimination algorithm.
Definition: GenericVariableElimination.hpp:72
AIToolbox::Factored::Factors
std::vector< size_t > Factors
Definition: Types.hpp:62
AIToolbox::Factored::GenericVariableElimination::Rules
std::vector< Rule > Rules
Definition: GenericVariableElimination.hpp:75
Core.hpp
AIToolbox::Factored::FactorGraph::bestVariableToRemove
size_t bestVariableToRemove(const Factors &F) const
This function returns the variable which is the cheapest to remove with GenericVariableElimination.
Definition: FactorGraph.hpp:446
AIToolbox::Factored
Definition: GraphUtils.hpp:12
AIToolbox::Factored::GenericVariableElimination::Graph
FactorGraph< Rules > Graph
Definition: GenericVariableElimination.hpp:76
AIToolbox::Factored::GenericVariableElimination::FinalFactors
std::vector< Factor > FinalFactors
Definition: GenericVariableElimination.hpp:77
FactorGraph.hpp
AIToolbox::Factored::PartialFactorsEnumerator
This class enumerates all possible values for a PartialFactors.
Definition: Core.hpp:531
AIToolbox::Factored::GenericVariableElimination::Rule
std::pair< size_t, Factor > Rule
Definition: GenericVariableElimination.hpp:74
AI_SEVERITY_INFO
#define AI_SEVERITY_INFO
Definition: Logging.hpp:69
AIToolbox::Impl::callFunction
void callFunction(F f, Args &&...args)
This function calls the input function with the specified arguments.
Definition: FunctionMatching.hpp:163
Logging.hpp
AI_LOGGER
#define AI_LOGGER(SEV, ARGS)
Definition: Logging.hpp:114
FunctionMatching.hpp
AI_SEVERITY_DEBUG
#define AI_SEVERITY_DEBUG
Definition: Logging.hpp:68