Skip to content

Commit

Permalink
[Autodiff] Deterministic gradient compute (apache#7321)
Browse files Browse the repository at this point in the history
* fix unstable compute

* fix

* fix

* lint

* sort linear equation

* sort inequalities

* fix

* fix find

* lint

* fix find

* lint
  • Loading branch information
hzfan authored and electriclilies committed Feb 18, 2021
1 parent e4ea17d commit 4c0fdf9
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 43 deletions.
9 changes: 4 additions & 5 deletions src/arith/solve_linear_equation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,11 +427,10 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol

// We have to transform ranges of the old variables into relations over new variables because
// new ranges are not enough usually.
for (const auto& p : system_to_solve->ranges) {
const Var& old_var = p.first;
const Range& old_range = p.second;
if (old_to_new_map.count(old_var)) {
PrimExpr express_by_new_vars = old_to_new_map[old_var];
for (const auto& old_var : system_to_solve->variables) {
if (system_to_solve->ranges.find(old_var) != system_to_solve->ranges.end()) {
const Range& old_range = system_to_solve->ranges.at(old_var);
PrimExpr express_by_new_vars = old_to_new_map.at(old_var);
PrimExpr lower_cond = analyzer_solution.Simplify(old_range->min <= express_by_new_vars);
PrimExpr upper_cond =
analyzer_solution.Simplify(express_by_new_vars < old_range->min + old_range->extent);
Expand Down
54 changes: 27 additions & 27 deletions src/arith/solve_linear_inequality.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,10 @@ struct ExprLess {
}
};

void DebugPrint(
const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& current_ineq_set,
const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& next_ineq_set,
const std::vector<PrimExpr>& rest, const std::vector<std::pair<int64_t, PrimExpr>>& coef_pos,
const std::vector<std::pair<int64_t, PrimExpr>>& coef_neg) {
void DebugPrint(const std::vector<PrimExpr>& current_ineq_set,
const std::vector<PrimExpr>& next_ineq_set, const std::vector<PrimExpr>& rest,
const std::vector<std::pair<int64_t, PrimExpr>>& coef_pos,
const std::vector<std::pair<int64_t, PrimExpr>>& coef_neg) {
std::cout << "Current ineq set:\n[";
for (auto& ineq : current_ineq_set) {
std::cout << ineq << ", ";
Expand Down Expand Up @@ -148,9 +147,12 @@ class NormalizeComparisons : public ExprMutator {
arith::Analyzer analyzer_;
};

void AddInequality(std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* inequality_set,
const PrimExpr& new_ineq, Analyzer* analyzer) {
if (analyzer->CanProve(new_ineq) || inequality_set->find(new_ineq) != inequality_set->end()) {
void AddInequality(std::vector<PrimExpr>* inequality_set, const PrimExpr& new_ineq,
Analyzer* analyzer) {
if (analyzer->CanProve(new_ineq) ||
std::find_if(inequality_set->begin(), inequality_set->end(), [&](const PrimExpr& e) {
return StructuralEqual()(e, new_ineq);
}) != inequality_set->end()) {
// redundant: follows from the vranges
// or has already been added
return;
Expand All @@ -168,15 +170,13 @@ void AddInequality(std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>
}
}

inequality_set->insert(new_ineq);
inequality_set->push_back(new_ineq);
}

void ClassifyByPolarity(
const Var& var,
const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& current_ineq_set,
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* next_ineq_set,
std::vector<PrimExpr>* rest, std::vector<std::pair<int64_t, PrimExpr>>* coef_pos,
std::vector<std::pair<int64_t, PrimExpr>>* coef_neg, Analyzer* analyzer) {
void ClassifyByPolarity(const Var& var, const std::vector<PrimExpr>& current_ineq_set,
std::vector<PrimExpr>* next_ineq_set, std::vector<PrimExpr>* rest,
std::vector<std::pair<int64_t, PrimExpr>>* coef_pos,
std::vector<std::pair<int64_t, PrimExpr>>* coef_neg, Analyzer* analyzer) {
// Take formulas from current_ineq_set and classify them according to polarity wrt var
// and store to coef_pos and coef_neg respectively.
for (const PrimExpr& ineq : current_ineq_set) {
Expand Down Expand Up @@ -218,14 +218,14 @@ void ClassifyByPolarity(
}
}

void MoveEquality(std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* upper_bounds,
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* lower_bounds,
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* equalities) {
void MoveEquality(std::vector<PrimExpr>* upper_bounds, std::vector<PrimExpr>* lower_bounds,
std::vector<PrimExpr>* equalities) {
// those exist in both upper & lower bounds will be moved to equalities
for (auto ub = upper_bounds->begin(); ub != upper_bounds->end();) {
auto lb = lower_bounds->find(*ub);
auto lb = std::find_if(lower_bounds->begin(), lower_bounds->end(),
[&](const PrimExpr& e) { return StructuralEqual()(e, *ub); });
if (lb != lower_bounds->end()) {
equalities->insert(*lb);
equalities->push_back(*lb);
lower_bounds->erase(lb);
ub = upper_bounds->erase(ub);
} else {
Expand All @@ -249,8 +249,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
// and move to the next variable.

// normalized inequality
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> current_ineq_set_to_solve;
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> next_ineq_set_to_solve;
std::vector<PrimExpr> current_ineq_set_to_solve;
std::vector<PrimExpr> next_ineq_set_to_solve;
// A vector of pairs (c, e), c > 0, representing formulas of the form c*v + e <= 0
std::vector<std::pair<int64_t, PrimExpr>> coef_pos;
// A vector of pairs (c, e), c < 0, representing formulas of the form c*v + e <= 0
Expand Down Expand Up @@ -321,8 +321,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
}

// The resulting lower and upper bounds
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> upper_bounds;
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> lower_bounds;
std::vector<PrimExpr> upper_bounds;
std::vector<PrimExpr> lower_bounds;
upper_bounds.reserve(coef_pos.size());
lower_bounds.reserve(coef_neg.size());

Expand All @@ -345,7 +345,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
}
}
// Add the upper bound
upper_bounds.insert(bound);
upper_bounds.push_back(bound);
}
for (const auto& neg : coef_neg) {
PrimExpr bound = make_const(v.dtype(), -coef_lcm / neg.first) * neg.second;
Expand All @@ -366,10 +366,10 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
}
}
// Add the lower bound
lower_bounds.insert(bound);
lower_bounds.push_back(bound);
}

std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> equal;
std::vector<PrimExpr> equal;
equal.reserve(std::min(upper_bounds.size(), lower_bounds.size()));
MoveEquality(&upper_bounds, &lower_bounds, &equal);
std::vector<PrimExpr> equal_list(equal.begin(), equal.end());
Expand Down
26 changes: 15 additions & 11 deletions src/te/autodiff/ad_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,15 +413,17 @@ class FactorOutAtomicFormulasFunctor
auto res_b = VisitExpr(op->b);

// For the And case we return the union of the sets of atomic formulas
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size());
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set;
res_a_set.reserve(res_a.atomic_formulas.size());
std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
std::inserter(res_set, res_set.end()));
std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
std::inserter(res_set, res_set.end()));

std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
std::inserter(res_a_set, res_a_set.end()));

std::vector<PrimExpr> res = res_a.atomic_formulas;
for (const auto& e : res_b.atomic_formulas) {
if (res_a_set.find(e) == res_a_set.end()) {
res.emplace_back(e);
}
}
// And the residuals are combined with &&
return {res, res_a.rest && res_b.rest};
}
Expand All @@ -443,32 +445,34 @@ class FactorOutAtomicFormulasFunctor

// For the Or case we intersect the sets of atomic formulas
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
std::vector<PrimExpr> res;
res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
for (const auto& res_b_formula : res_b_set) {
res.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
for (const auto& res_b_formula : res_b.atomic_formulas) {
if (res_a_set.count(res_b_formula)) {
res_set.insert(res_b_formula);
res.push_back(res_b_formula);
}
}

// Computing the residual is more complex: we have to compute the sets of atomic formulas
// which are left behind, and then combine them with the residuals into the new residual.
std::vector<PrimExpr> new_cond_a;
new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
for (const auto& formula : res_a_set) {
for (const auto& formula : res_a.atomic_formulas) {
if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
}

std::vector<PrimExpr> new_cond_b;
new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
for (const auto& formula : res_b_set) {
for (const auto& formula : res_b.atomic_formulas) {
if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
}

res_a.atomic_formulas = std::move(new_cond_a);
res_b.atomic_formulas = std::move(new_cond_b);

PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
std::vector<PrimExpr> res{res_set.begin(), res_set.end()};

return {res, new_rest};
}
Expand Down

0 comments on commit 4c0fdf9

Please sign in to comment.