Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EGRAPH 2023 Eval #10

Open
wants to merge 18 commits into
base: 3la-pldi-flexmatch-eval
Choose a base branch
from
Prev Previous commit
Next Next commit
compare with removing redaundant constraints
  • Loading branch information
AD1024 committed Mar 13, 2023
commit 6bd17d345c4d7dbc06098763a93beb85faf67356
5 changes: 5 additions & 0 deletions flexmatch/draw_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ def draw_extraction_times(data: dict):
ax.bar(x + bar_width, solver_times['ILP-Topo'], bar_width, label='ILP-Topo', color='lightblue', edgecolor='grey')
overhead= ax.bar(x + bar_width, overhead_times['ILP-Topo'], bar_width, bottom=solver_times['ILP-Topo'], color='slateblue', hatch='//', edgecolor='grey')

print((np.array(solver_times['ILP-Topo']) + overhead_times['ILP-Topo']) / (np.array(solver_times['ILP-ACyc']) + overhead_times['ILP-ACyc']))
print((np.array(solver_times['ILP-Topo']) + overhead_times['ILP-Topo']) / (np.array(solver_times['WPMAXSAT']) + overhead_times['WPMAXSAT']))
# print(solver_times['ILP-ACyc'])
# print(solver_times['WPMAXSAT'])

ax.set_xticks(x, data.keys())
ax.set_xticklabels(x_ticks)

Expand Down
54 changes: 37 additions & 17 deletions flexmatch/src/ilp_extract.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::{collections::HashMap, time::Instant};
use std::{
collections::{HashMap, HashSet},
time::Instant,
};

use egg::{Analysis, EGraph, Id, Language, RecExpr};
use rand::Rng;
Expand All @@ -11,6 +14,7 @@ fn get_all_cycles<'a, L, N>(
path: &mut Vec<(Id, L)>,
problem: &mut Problem<'a>,
node_vars: &HashMap<L, usize>,
node_to_children: &HashMap<usize, HashSet<Id>>,
) where
L: Language,
N: Analysis<L>,
Expand All @@ -22,8 +26,8 @@ fn get_all_cycles<'a, L, N>(
if let Some((idx, _)) = path.iter().enumerate().find(|(_, (id, _))| id == root) {
let mut new_cycle = Vec::new();
let subpath = path[idx..].to_vec();
for (_, n) in subpath {
new_cycle.push(node_vars[&n]);
for (_, n) in &subpath {
new_cycle.push(node_vars[n]);
}
let mut rng = rand::thread_rng();
if new_cycle.len() == 1 {
Expand All @@ -35,18 +39,21 @@ fn get_all_cycles<'a, L, N>(
constraint.add_wvar(WeightedVariable::new_idx(new_cycle[0], 1.0));
problem.add_constraint(constraint).unwrap();
} else {
for node_idx in egraph[*root].nodes.iter().map(|x| node_vars[x]) {
new_cycle[0] = node_idx;
// sum up <= len(new_cycle) - 1
let mut constraint = rplex::Constraint::new(
rplex::ConstraintType::LessThanEq,
new_cycle.len() as f64 - 1.0,
format!("cycle_{}_{}", root, rng.gen::<u64>()),
);
for node_idx in new_cycle.iter() {
constraint.add_wvar(WeightedVariable::new_idx(*node_idx, 1.0));
let nxt_hop = subpath[1].0;
for node_idx in egraph[*root].nodes.iter().map(|n| node_vars[n]) {
if node_to_children[&node_idx].contains(&nxt_hop) {
new_cycle[0] = node_idx;
// sum up <= len(new_cycle) - 1
let mut constraint = rplex::Constraint::new(
rplex::ConstraintType::LessThanEq,
new_cycle.len() as f64 - 1.0,
format!("cycle_{}_{}", root, rng.gen::<u64>()),
);
for node_idx in new_cycle.iter() {
constraint.add_wvar(WeightedVariable::new_idx(*node_idx, 1.0));
}
problem.add_constraint(constraint).unwrap();
}
problem.add_constraint(constraint).unwrap();
}
}
return;
Expand All @@ -55,11 +62,19 @@ fn get_all_cycles<'a, L, N>(
}
color.insert(*root, 1);
for node in egraph[*root].nodes.iter() {
path.push((*root, node.clone()));
for ch in node.children() {
path.push((*root, node.clone()));
get_all_cycles(egraph, ch, color, path, problem, node_vars);
path.pop();
get_all_cycles(
egraph,
ch,
color,
path,
problem,
node_vars,
node_to_children,
);
}
path.pop();
}
color.insert(*root, 2);
}
Expand Down Expand Up @@ -203,12 +218,15 @@ where
constraint.add_wvar(WeightedVariable::new_idx(node_idx, 1.0));
}
problem.add_constraint(constraint).unwrap();
let mut node_to_children = HashMap::new();

// children constraint
for eclass in egraph.classes() {
for node in egraph[eclass.id].nodes.iter() {
let node_idx = node_vars[node];
let mut node_children_set = HashSet::new();
for (ch_idx, ch) in node.children().iter().enumerate() {
node_children_set.insert(*ch);
let mut constraint = rplex::Constraint::new(
rplex::ConstraintType::GreaterThanEq,
0.0,
Expand All @@ -220,6 +238,7 @@ where
}
problem.add_constraint(constraint).unwrap();
}
node_to_children.insert(node_idx, node_children_set);
}
}

Expand Down Expand Up @@ -260,6 +279,7 @@ where
&mut path,
&mut problem,
&node_vars,
&node_to_children,
);
}
}
Expand Down
11 changes: 10 additions & 1 deletion flexmatch/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,16 @@ fn main() {
};

let aggregated_configs = read_configs(&flexmatch_home, config_files);
let mut rewrites = vec![];
let mut rewrites = vec![
// glenside::language::rewrites::collapse_nested_transposes(),
// glenside::language::rewrites::simplify_multiple_transposes(),
// glenside::language::rewrites::simplify_multiple_accesses(),
// glenside::language::rewrites::simplify_reduce_max(),
// glenside::language::rewrites::bubble_access_reshape_through_compute_reduce_max(),
// glenside::language::rewrites::simplify_multiple_access_reshapes(),
// glenside::language::rewrites::collapse_nested_accesses(),
// glenside::language::rewrites::flatten_dot_product_to_dense(),
];
let mut rewrite_set = HashSet::new();
debug!("{:?}", aggregated_configs);
for config in aggregated_configs.iter() {
Expand Down
26 changes: 22 additions & 4 deletions flexmatch/src/maxsat_extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ fn get_all_cycles<L, N>(
path: &mut Vec<(Id, L)>,
problem_writer: &mut ProblemWriter,
node_vars: &HashMap<L, usize>,
node_to_children: &HashMap<usize, HashSet<Id>>,
top: f64,
) where
L: Language,
Expand All @@ -32,15 +33,18 @@ fn get_all_cycles<L, N>(
if let Some((idx, _)) = path.iter().enumerate().find(|(_, (id, _))| id == root) {
let mut new_cycle = Vec::new();
let subpath = path[idx..].to_vec();
for (_, n) in subpath {
for (_, n) in &subpath {
new_cycle.push(node_vars[&n]);
}
if new_cycle.len() == 1 {
problem_writer.hard_clause(&format!("-{}", new_cycle[0]), top);
} else {
let nxt_hop = subpath[1].0;
for node_idx in egraph[*root].nodes.iter().map(|x| node_vars[x]) {
new_cycle[0] = node_idx;
disjuct_negative(&new_cycle, problem_writer, top);
// if node_to_children[&node_idx].contains(&nxt_hop) {
new_cycle[0] = node_idx;
disjuct_negative(&new_cycle, problem_writer, top);
// }
}
}
return;
Expand All @@ -53,7 +57,16 @@ fn get_all_cycles<L, N>(
// let mut to_here = path.clone();
// to_here.push((*root, node.clone()));
path.push((*root, node.clone()));
get_all_cycles(egraph, ch, color, path, problem_writer, node_vars, top);
get_all_cycles(
egraph,
ch,
color,
path,
problem_writer,
node_vars,
node_to_children,
top,
);
path.pop();
}
}
Expand Down Expand Up @@ -354,18 +367,22 @@ where
.join(" ");
hard_clauses.push(root_clause);

let mut node_to_children = HashMap::new();
// children constraint
for c in self.egraph.classes() {
for n in c.nodes.iter() {
// v_n -> \bigvee_cN v_cN forall C
let mut node_children = HashSet::new();
for ch in n.children() {
node_children.insert(*ch);
let mut clause = String::new();
clause.push_str(&format!("-{}", node_vars[n]));
for ch_node in self.egraph[*ch].nodes.iter() {
clause.push_str(&format!(" {}", node_vars[ch_node]));
}
hard_clauses.push(clause);
}
node_to_children.insert(node_vars[n], node_children);
}
}

Expand All @@ -380,6 +397,7 @@ where
&mut path,
&mut self.writer,
&node_vars,
&node_to_children,
top,
);
}
Expand Down
1 change: 1 addition & 0 deletions flexmatch/src/rewrites.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub fn get_rewrite_from_string(name: &String, args: &Box<[i32]>) -> Rewrite<Lang
"simplify-multiple-access-reshapes" => simplify_multiple_access_reshapes(),
"bubble-access-through-access-transpose" => bubble_access_through_access_transpose(),
"simplify-reduce-max" => simplify_reduce_max(),
"collapse-nested-transposes" => collapse_nested_transposes(),

"flex-linear-rewrite" => linear_layer_accelerator_rewrites(),
"flex-linear-dense" => dot_product_to_linear(),
Expand Down