1 #ifndef AI_TOOLBOX_FACTORED_GENERIC_VARIABLE_ELIMINATION_HEADER_FILE
2 #define AI_TOOLBOX_FACTORED_GENERIC_VARIABLE_ELIMINATION_HEADER_FILE
71 template <
typename Factor>
// Rule: pairs an index (`first`) with the Factor data stored for it
// (`second`). During elimination, `first` is compared against the result of
// toIndexPartial(...) over a factor's variables, so it identifies a partial
// joint assignment of those variables.
74 using Rule = std::pair<size_t, Factor>;
// Rules: a list of Rule entries. Presumably the per-factor payload returned by
// getData() on graph factors — TODO confirm against the Graph alias.
// NOTE(review): when the Global provides mergeFactors, lookups use
// std::lower_bound on `first`, and new rules are inserted in index order.
// This presumes each Rules vector is kept sorted by `first`.
75 using Rules = std::vector<Rule>;
86 template <
typename Global>
96 struct global_interface;
107 template <
typename Global>
111 template <
typename Factor>
112 template <
typename M>
// STR(X): turns X into a string literal. It is used in MEMBER_CHECK's
// static_assert message to name the offending member. The extra indirection
// through STR2 (defined elsewhere; presumably `#X`) is the usual two-step
// stringification idiom. Verify STR2's definition.
115 #define STR(X) STR2(X)
// ARG(...): expands to its arguments unchanged. Presumably used to pass
// comma-containing types/signatures as a single macro argument
// (e.g. to MEMBER_CHECK). TODO confirm at call sites.
116 #define ARG(...) __VA_ARGS__
127 #define MEMBER_CHECK(name, retval, input) \
131 template <typename Z> static constexpr auto name##Check(char) -> decltype( \
132 static_cast<retval (Z::*)(input)> (&Z::name), \
135 template <typename Z> static constexpr auto name##Check(int) -> decltype( \
136 static_cast<retval (Z::*)(input) const> (&Z::name), \
139 template <typename Z> static constexpr auto name##Check(long) -> decltype( \
143 static_assert(Impl::is_compatible_f< \
144 decltype(&Z::name), \
146 >::value, "You provide a member '" STR(name) "' but with the wrong signature."); \
149 template <typename Z> static constexpr auto name##Check(...) -> bool { return false; } \
153 name = name##Check<M>('\0') \
173 template <
typename Factor>
174 template <
typename Global>
176 static_assert(global_interface<Global>::crossSum,
"You must provide a crossSum method!");
177 static_assert(global_interface<Global>::makeResult,
"You must provide a makeResult method!");
178 static_assert(std::is_same_v<Factor, decltype(global.newFactor)>,
"You must provide a public 'Factor newFactor;' member!");
187 global.makeResult(std::move(finalFactors));
190 template <
typename Factor>
191 template <
typename Global>
197 const auto & factors = graph.getFactors(v);
198 const auto & vNeighbors = graph.getVariables(v);
200 if constexpr(global_interface<Global>::beginRemoval)
205 const bool isFinalFactor = vNeighbors.size() == 0;
208 size_t oldRulesCurrId = 0;
211 const auto id = jointValues.getFactorToSkipId();
213 if (!isFinalFactor) {
214 oldRulesP = &graph.getFactor(vNeighbors)->getData();
215 oldRulesP->reserve(jointValues.size());
220 "Width of this factor: " << vNeighbors.size() + 1 <<
". "
221 "Joint values to iterate: " << jointValues.size() * V[v]
225 while (jointValues.isValid()) {
226 auto & jointValue = *jointValues;
228 if constexpr(global_interface<Global>::initNewFactor)
229 global.initNewFactor();
234 for (
size_t vValue = 0; vValue < V[v]; ++vValue) {
235 if constexpr(global_interface<Global>::beginCrossSum)
238 jointValue.second[id] = vValue;
239 for (
const auto factor : factors) {
240 if constexpr(global_interface<Global>::beginFactorCrossSum)
241 global.beginFactorCrossSum();
245 const size_t jvPartialIndex =
toIndexPartial(factor->getVariables(), V, jointValue);
246 if constexpr(global_interface<Global>::mergeFactors) {
247 const auto & data = factor->getData();
248 const auto ruleIt = std::lower_bound(
252 [](
const Rule & lhs,
const size_t rhs) {
253 return lhs.first < rhs;
256 if (ruleIt != std::end(data) && ruleIt->first == jvPartialIndex)
257 global.crossSum(ruleIt->second);
259 for (
const auto & rule : factor->getData())
260 if (jvPartialIndex == rule.first)
261 global.crossSum(rule.second);
264 if constexpr(global_interface<Global>::endFactorCrossSum)
265 global.endFactorCrossSum();
268 if constexpr(global_interface<Global>::endCrossSum)
269 global.endCrossSum();
272 bool isValidNewFactor =
true;
273 if constexpr(global_interface<Global>::isValidNewFactor)
274 isValidNewFactor = global.isValidNewFactor();
280 if (isValidNewFactor) {
281 if (!isFinalFactor) {
282 auto & oldRules = *oldRulesP;
289 if constexpr(global_interface<Global>::mergeFactors) {
290 while (oldRulesCurrId < oldRules.size() && oldRules[oldRulesCurrId].first < jvID)
293 if (oldRulesCurrId < oldRules.size() && oldRules[oldRulesCurrId].first == jvID) {
294 global.mergeFactors(oldRules[oldRulesCurrId].second, std::move(global.newFactor));
296 oldRules.emplace(std::begin(oldRules) + oldRulesCurrId, jvID, std::move(global.newFactor));
303 oldRules.emplace_back(jvID, std::move(global.newFactor));
308 finalFactors.push_back(std::move(global.newFactor));
311 jointValues.advance();