Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[venom]: avoid last swap for commutative ops #4048

Merged
merged 19 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vyper/venom/analysis/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def analyze(self):
# dfg_inputs of %15 is all the instructions which *use* %15, ex. [(%16 = iszero %15), ...]
for bb in self.function.get_basic_blocks():
for inst in bb.instructions:
operands = inst.get_inputs()
operands = inst.get_input_variables()
res = inst.get_outputs()

for op in operands:
Expand Down
2 changes: 1 addition & 1 deletion vyper/venom/analysis/dup_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def analyze(self):
last_liveness = bb.out_vars
for inst in reversed(bb.instructions):
inst.dup_requirements = OrderedSet()
ops = inst.get_inputs()
ops = inst.get_input_variables()
for op in ops:
if op in last_liveness:
inst.dup_requirements.add(op)
Expand Down
2 changes: 1 addition & 1 deletion vyper/venom/analysis/liveness.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _calculate_liveness(self, bb: IRBasicBlock) -> bool:
orig_liveness = bb.instructions[0].liveness.copy()
liveness = bb.out_vars.copy()
for instruction in reversed(bb.instructions):
ins = instruction.get_inputs()
ins = instruction.get_input_variables()
outs = instruction.get_outputs()

if ins or outs:
Expand Down
4 changes: 2 additions & 2 deletions vyper/venom/basicblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def get_non_label_operands(self) -> Iterator[IROperand]:
"""
return (op for op in self.operands if not isinstance(op, IRLabel))

def get_inputs(self) -> Iterator[IRVariable]:
def get_input_variables(self) -> Iterator[IRVariable]:
"""
Get all input operands for instruction.
"""
Expand Down Expand Up @@ -477,7 +477,7 @@ def get_assignments(self):
def get_uses(self) -> dict[IRVariable, OrderedSet[IRInstruction]]:
uses: dict[IRVariable, OrderedSet[IRInstruction]] = {}
for inst in self.instructions:
for op in inst.get_inputs():
for op in inst.get_input_variables():
if op not in uses:
uses[op] = OrderedSet()
uses[op].add(inst)
Expand Down
2 changes: 1 addition & 1 deletion vyper/venom/passes/dft.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _process_instruction_r(self, bb: IRBasicBlock, inst: IRInstruction, offset:
self.inst_order[inst] = 0
return

for op in inst.get_inputs():
for op in inst.get_input_variables():
target = self.dfg.get_producing_instruction(op)
assert target is not None, f"no producing instruction for {op}"
if target.parent != inst.parent or target.fence_id != inst.fence_id:
Expand Down
2 changes: 1 addition & 1 deletion vyper/venom/passes/remove_unused_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _process_instruction(self, inst):
if len(uses) > 0:
return

for operand in inst.get_inputs():
for operand in inst.get_input_variables():
self.dfg.remove_use(operand, inst)
new_uses = self.dfg.get_uses(operand)
self.work_list.addmany(new_uses)
Expand Down
35 changes: 27 additions & 8 deletions vyper/venom/venom_to_assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@
]
)

COMMUTATIVE_INSTRUCTIONS = frozenset(["add", "mul", "smul", "or", "xor", "and", "eq"])


_REVERT_POSTAMBLE = ["_sym___revert", "JUMPDEST", *PUSH(0), "DUP1", "REVERT"]


Expand Down Expand Up @@ -195,8 +198,14 @@ def generate_evm(self, no_optimize: bool = False) -> list[str]:
return top_asm

def _stack_reorder(
self, assembly: list, stack: StackModel, stack_ops: list[IRVariable]
) -> None:
self, assembly: list, stack: StackModel, stack_ops: list[IROperand], dry_run: bool = False
) -> int:
cost = 0

if dry_run:
assert len(assembly) == 0, "Dry run should not work on assembly"
stack = stack.copy()

stack_ops_count = len(stack_ops)

counts = Counter(stack_ops)
Expand All @@ -216,8 +225,10 @@ def _stack_reorder(
if op == stack.peek(final_stack_depth):
continue

self.swap(assembly, stack, depth)
self.swap(assembly, stack, final_stack_depth)
cost += self.swap(assembly, stack, depth)
cost += self.swap(assembly, stack, final_stack_depth)

return cost

def _emit_input_operands(
self, assembly: list, inst: IRInstruction, ops: list[IROperand], stack: StackModel
Expand Down Expand Up @@ -376,7 +387,7 @@ def _generate_evm_for_instruction(

if opcode == "phi":
ret = inst.get_outputs()[0]
phis = list(inst.get_inputs())
phis = list(inst.get_input_variables())
depth = stack.get_phi_depth(phis)
# collapse the arguments to the phi node in the stack.
# example, for `%56 = %label1 %13 %label2 %14`, we will
Expand Down Expand Up @@ -406,9 +417,16 @@ def _generate_evm_for_instruction(
target_stack_list = list(target_stack)
self._stack_reorder(assembly, stack, target_stack_list)

if opcode in COMMUTATIVE_INSTRUCTIONS:
cost_no_swap = self._stack_reorder([], stack, operands, dry_run=True)
operands[-1], operands[-2] = operands[-2], operands[-1]
cost_with_swap = self._stack_reorder([], stack, operands, dry_run=True)
if cost_with_swap > cost_no_swap:
operands[-1], operands[-2] = operands[-2], operands[-1]

# final step to get the inputs to this instruction ordered
# correctly on the stack
self._stack_reorder(assembly, stack, operands) # type: ignore
self._stack_reorder(assembly, stack, operands)

# some instructions (i.e. invoke) need to do stack manipulations
# with the stack model containing the return value(s), so we fiddle
Expand Down Expand Up @@ -533,13 +551,14 @@ def pop(self, assembly, stack, num=1):
stack.pop(num)
assembly.extend(["POP"] * num)

def swap(self, assembly, stack, depth):
def swap(self, assembly, stack, depth) -> int:
# Swaps of the top is no op
if depth == 0:
return
return 0

stack.swap(depth)
assembly.append(_evm_swap_for(depth))
return 1

def dup(self, assembly, stack, depth):
stack.dup(depth)
Expand Down
Loading