/// Default constructor: out-of-line definition with no member-initializer list
/// and an empty body, so it performs no explicit work of its own.
32 TableFactor::TableFactor() {}
50 double denom =
table.size();
53 denominators_.insert(std::pair<Key, double>(dkey.first, denom));
79 size_t cardinalityProduct = 1;
80 for (
auto&& [
_,
c] :
dt.cardinalities()) {
81 cardinalityProduct *=
c;
85 dt.visit([&nrValues](
double x) {
86 if (
x > 0) nrValues += 1;
90 KeySet allKeys(
dt.keys().begin(),
dt.keys().end());
93 std::map<Key, size_t> denominators;
94 double denom = sparseTable.size();
97 denominators.
insert(std::pair<Key, double>(dkey.first, denom));
116 for (
auto&& [
k,
_] : assignment) {
117 assignmentKeys.insert(
k);
122 std::set_difference(allKeys.begin(), allKeys.end(),
123 assignmentKeys.begin(), assignmentKeys.end(),
124 std::back_inserter(diff));
128 for (
auto&&
key : diff) {
129 extras.push_back({
key,
dt.cardinality(
key)});
133 for (
auto&& extra : extraAssignments) {
136 updatedAssignment.
insert(extra);
141 for (
auto&& it = updatedAssignment.rbegin();
142 it != updatedAssignment.rend(); it++) {
143 idx += it->second * denominators.at(it->first);
177 if (
table.size() != max_size) {
178 throw std::runtime_error(
179 "The cardinalities of the keys don't match the number of values in the "
191 sparse_table.pruned();
198 const std::string&
table) {
200 std::vector<double> ys;
201 std::istringstream iss(
table);
202 std::copy(std::istream_iterator<double>(iss), std::istream_iterator<double>(),
203 std::back_inserter(ys));
224 idx += card *
values.at(it->first);
235 for (
auto it =
keys_.rbegin(); it !=
keys_.rend(); ++it) {
237 idx += card *
values.at(*it);
263 if (
auto tf = std::dynamic_pointer_cast<TableFactor>(
f)) {
265 result = std::make_shared<TableFactor>(this->
operator*(*tf));
267 }
else if (
auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(
f)) {
279 result = std::make_shared<DecisionTreeFactor>(
280 f->operator*(
this->toDecisionTreeFactor()));
288 if (
auto tf = std::dynamic_pointer_cast<TableFactor>(
f)) {
289 return std::make_shared<TableFactor>(this->
operator/(*tf));
290 }
else if (
auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(
f)) {
291 return std::make_shared<TableFactor>(
292 this->
operator/(
TableFactor(f->discreteKeys(), *dtf)));
295 return std::make_shared<TableFactor>(this->
operator/(divisor));
304 if (dkeys.size() == 0) {
315 table[it.index()] = it.value();
326 if (parent_keys.empty())
return *
this;
331 for (
auto it =
keys_.rbegin(); it !=
keys_.rend(); ++it) {
332 if (parent_assign.find(*it) != parent_assign.end()) {
333 unique += parent_assign.at(*it) * card;
340 std::sort(parent_keys.begin(), parent_keys.end());
342 parent_keys.begin(), parent_keys.end(),
343 std::back_inserter(child_dkeys));
348 child_card *= child_dkey.second;
350 child_sparse_table_.
reserve(child_card);
357 if (parent_unique == unique) {
359 child_sparse_table_.
insert(idx) = it.value();
363 child_sparse_table_.pruned();
365 return TableFactor(child_dkeys, child_sparse_table_);
373 return (
a == 0 ||
b == 0) ? 0 : (
a /
b);
382 cout <<
" ]" << endl;
385 for (
auto&& kv : assignment) {
386 cout <<
"(" <<
formatter(kv.first) <<
", " << kv.second <<
")";
388 cout <<
" | " << std::setw(10) <<
std::left << it.value() <<
" | "
389 << it.index() << endl;
406 double max_value = std::numeric_limits<double>::lowest();
408 max_value =
std::max(max_value, it.value());
433 sparse_table.
coeffRef(it.index()) = op(it.value());
437 sparse_table.pruned();
453 sparse_table.
coeffRef(it.index()) = op(assignment, it.value());
457 sparse_table.pruned();
466 else if (
f.keys_.empty() &&
f.sparse_table_.nonZeros() == 0)
473 unordered_map<uint64_t, AssignValList> map_f =
474 f.createMap(contract_dkeys, f_free_dkeys);
477 for (
auto u_dkey : union_dkeys) card *= u_dkey.second;
479 mult_sparse_table.
reserve(card);
483 if (map_f.find(contract_unique) == map_f.end())
continue;
484 for (
auto assignVal : map_f[contract_unique]) {
486 mult_sparse_table.
insert(union_idx) = op(it.value(), assignVal.second);
490 mult_sparse_table.pruned();
493 return TableFactor(union_dkeys, mult_sparse_table);
501 f.sorted_dkeys_.begin(),
f.sorted_dkeys_.end(),
502 back_inserter(contract));
511 f.sorted_dkeys_.begin(),
f.sorted_dkeys_.end(),
512 back_inserter(free));
521 f.sorted_dkeys_.end(), back_inserter(union_dkeys));
530 for (
auto it = union_keys.rbegin(); it != union_keys.rend(); it++) {
531 if (f_free.find(it->first) == f_free.end()) {
534 union_idx += f_free.at(it->first) * card;
545 unordered_map<uint64_t, AssignValList> map_f;
552 for (
auto&
key : free)
555 if (map_f.find(unique_rep) == map_f.end()) {
556 map_f[unique_rep] = {make_pair(free_assignments, it.value())};
558 map_f[unique_rep].push_back(make_pair(free_assignments, it.value()));
567 if (dkeys.empty())
return 0;
569 for (
auto it = dkeys.rbegin(); it != dkeys.rend(); it++) {
578 if (assignments.empty())
return 0;
580 for (
auto it = assignments.rbegin(); it != assignments.rend(); it++) {
581 unique_rep += it->second * card;
599 if (nrFrontals >
size()) {
600 throw invalid_argument(
601 "TableFactor::combine: invalid number of frontal "
603 to_string(nrFrontals) +
", nr.keys=" + std::to_string(
size()));
608 for (
auto i = nrFrontals;
i <
keys_.size();
i++) {
618 double new_val = op(combined_table.
coeff(idx), it.value());
619 combined_table.
coeffRef(idx) = new_val;
622 combined_table.pruned();
624 return std::make_shared<TableFactor>(remain_dkeys, combined_table);
630 if (frontalKeys.size() >
size()) {
631 throw invalid_argument(
632 "TableFactor::combine: invalid number of frontal "
634 std::to_string(frontalKeys.size()) +
635 ", nr.keys=" + std::to_string(
size()));
641 if (std::find(frontalKeys.begin(), frontalKeys.end(),
key) ==
653 double new_val = op(combined_table.
coeff(idx), it.value());
654 combined_table.
coeffRef(idx) = new_val;
657 combined_table.pruned();
659 return std::make_shared<TableFactor>(remain_dkeys, combined_table);
671 std::vector<std::pair<Key, size_t>> pairs =
discreteKeys();
673 std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
676 std::vector<std::pair<DiscreteValues, double>>
result;
677 for (
const auto& assignment : assignments) {
678 result.emplace_back(assignment,
operator()(assignment));
692 ss << keyFormatter(
key) <<
"|";
698 for (
size_t j = 0;
j <
size();
j++)
ss <<
":-:|";
706 size_t index = assignment.at(
key);
709 ss << it.value() <<
"|\n";
720 ss <<
"<div>\n<table class='TableFactor'>\n <thead>\n";
725 ss <<
"<th>" << keyFormatter(
key) <<
"</th>";
727 ss <<
"<th>value</th></tr>\n";
730 ss <<
" </thead>\n <tbody>\n";
737 size_t index = assignment.at(
key);
740 ss <<
"<td>" << it.value() <<
"</td>";
743 ss <<
" </tbody>\n</table>\n</div>";
749 const size_t N = maxNrAssignments;
752 vector<pair<Eigen::Index, double>> probabilities;
756 probabilities.emplace_back(it.index(), it.value());
760 if (probabilities.size() <=
N)
return *
this;
763 sort(probabilities.begin(), probabilities.end(),
764 [](
const std::pair<Eigen::Index, double>&
a,
765 const std::pair<Eigen::Index, double>&
b) {
766 return a.second > b.second;
770 if (probabilities.size() >
N) probabilities.resize(
N);
774 pruned_vec.
reserve(probabilities.size());
777 for (
const auto& prob : probabilities) {
778 pruned_vec.
insert(prob.first) = prob.second;
788 throw std::runtime_error(
"TableFactor::restrict not implemented");