Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[draft] refactor: constant folding #1

Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor constant folder into its own visitor instead of embedding im…
…plementation on AST nodes
charles-cooper committed Jan 7, 2024
commit 7b5da3b0d0934dd80350e4e68d51a63d25ef6d4e
178 changes: 4 additions & 174 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
@@ -407,14 +407,10 @@ def get_folded_value(self) -> "VyperNode":
For constant/literal nodes, the node should be directly returned
without caching to the metadata.
"""
if self.is_literal_value:
return self

if "folded_value" not in self._metadata:
res = self._try_fold() # possibly throws UnfoldableNode
self._set_folded_value(res)

return self._metadata["folded_value"]
try:
return self._metadata["folded_value"]
except KeyError:
raise UnfoldableNode("not foldable", self)

def _set_folded_value(self, node: "VyperNode") -> None:
# sanity check this is only called once
@@ -430,17 +426,6 @@ def _set_folded_value(self, node: "VyperNode") -> None:
def get_original_node(self) -> "VyperNode":
return self._original_node or self

def _try_fold(self) -> "VyperNode":
"""
Attempt to constant-fold the content of a node, returning the result of
constant-folding if possible.
If a node cannot be folded, it should raise `UnfoldableNode`. This
base implementation acts as a catch-all to raise on any inherited
classes that do not implement the method.
"""
raise UnfoldableNode(f"{type(self)} cannot be folded")

def validate(self) -> None:
"""
Validate the content of a node.
@@ -919,10 +904,6 @@ class List(ExprNode):
def is_literal_value(self):
return all(e.is_literal_value for e in self.elements)

def _try_fold(self) -> ExprNode:
elements = [e.get_folded_value() for e in self.elements]
return type(self).from_node(self, elements=elements)


class Tuple(ExprNode):
__slots__ = ("elements",)
@@ -936,10 +917,6 @@ def validate(self):
if not self.elements:
raise InvalidLiteral("Cannot have an empty tuple", self)

def _try_fold(self) -> ExprNode:
elements = [e.get_folded_value() for e in self.elements]
return type(self).from_node(self, elements=elements)


class NameConstant(Constant):
__slots__ = ()
@@ -960,10 +937,6 @@ class Dict(ExprNode):
def is_literal_value(self):
return all(v.is_literal_value for v in self.values)

def _try_fold(self) -> ExprNode:
values = [v.get_folded_value() for v in self.values]
return type(self).from_node(self, values=values)


class Name(ExprNode):
__slots__ = ("id",)
@@ -972,27 +945,6 @@ class Name(ExprNode):
class UnaryOp(ExprNode):
__slots__ = ("op", "operand")

def _try_fold(self) -> ExprNode:
"""
Attempt to evaluate the unary operation.
Returns
-------
Int | Decimal
Node representing the result of the evaluation.
"""
operand = self.operand.get_folded_value()

if isinstance(self.op, Not) and not isinstance(operand, NameConstant):
raise UnfoldableNode("not a boolean!", self.operand)
if isinstance(self.op, USub) and not isinstance(operand, Num):
raise UnfoldableNode("not a number!", self.operand)
if isinstance(self.op, Invert) and not isinstance(operand, Int):
raise UnfoldableNode("not an int!", self.operand)

value = self.op._op(operand.value)
return type(operand).from_node(self, value=value)


class Operator(VyperNode):
pass
@@ -1021,30 +973,6 @@ def _op(self, value):
class BinOp(ExprNode):
__slots__ = ("left", "op", "right")

def _try_fold(self) -> ExprNode:
"""
Attempt to evaluate the arithmetic operation.
Returns
-------
Int | Decimal
Node representing the result of the evaluation.
"""
left, right = [i.get_folded_value() for i in (self.left, self.right)]
if type(left) is not type(right):
raise UnfoldableNode("invalid operation", self)
if not isinstance(left, Num):
raise UnfoldableNode("not a number!", self.left)

# this validation is performed to prevent the compiler from hanging
# on very large shifts and improve the error message for negative
# values.
if isinstance(self.op, (LShift, RShift)) and not (0 <= right.value <= 256):
raise InvalidLiteral("Shift bits must be between 0 and 256", self.right)

value = self.op._op(left.value, right.value)
return type(left).from_node(self, value=value)


class Add(Operator):
__slots__ = ()
@@ -1170,24 +1098,6 @@ class RShift(Operator):
class BoolOp(ExprNode):
__slots__ = ("op", "values")

def _try_fold(self) -> ExprNode:
"""
Attempt to evaluate the boolean operation.
Returns
-------
NameConstant
Node representing the result of the evaluation.
"""
values = [v.get_folded_value() for v in self.values]

if any(not isinstance(v, NameConstant) for v in values):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")

values = [v.value for v in values]
value = self.op._op(values)
return NameConstant.from_node(self, value=value)


class And(Operator):
__slots__ = ()
@@ -1225,40 +1135,6 @@ def __init__(self, *args, **kwargs):
kwargs["right"] = kwargs.pop("comparators")[0]
super().__init__(*args, **kwargs)

def _try_fold(self) -> ExprNode:
"""
Attempt to evaluate the comparison.
Returns
-------
NameConstant
Node representing the result of the evaluation.
"""
left, right = [i.get_folded_value() for i in (self.left, self.right)]
if not isinstance(left, Constant):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")

# CMC 2022-08-04 we could probably remove these evaluation rules as they
# are taken care of in the IR optimizer now.
if isinstance(self.op, (In, NotIn)):
if not isinstance(right, List):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")
if next((i for i in right.elements if not isinstance(i, Constant)), None):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")
if len(set([type(i) for i in right.elements])) > 1:
raise UnfoldableNode("List contains multiple literal types")
value = self.op._op(left.value, [i.value for i in right.elements])
return NameConstant.from_node(self, value=value)

if not isinstance(left, type(right)):
raise UnfoldableNode("Cannot compare different literal types")

if not isinstance(self.op, (Eq, NotEq)) and not isinstance(left, (Int, Decimal)):
raise TypeMismatch(f"Invalid literal types for {self.op.description} comparison", self)

value = self.op._op(left.value, right.value)
return NameConstant.from_node(self, value=value)


class Eq(Operator):
__slots__ = ()
@@ -1315,21 +1191,6 @@ def _op(self, left, right):
class Call(ExprNode):
__slots__ = ("func", "args", "keywords")

# try checking if this is a builtin, which is foldable
def _try_fold(self):
if not isinstance(self.func, Name):
raise UnfoldableNode("not a builtin", self)

# cursed import cycle!
from vyper.builtins.functions import DISPATCH_TABLE

func_name = self.func.id
if func_name not in DISPATCH_TABLE:
raise UnfoldableNode("not a builtin", self)

builtin_t = DISPATCH_TABLE[func_name]
return builtin_t._try_fold(self)


class keyword(VyperNode):
__slots__ = ("arg", "value")
@@ -1342,37 +1203,6 @@ class Attribute(ExprNode):
class Subscript(ExprNode):
__slots__ = ("slice", "value")

def _try_fold(self) -> ExprNode:
"""
Attempt to evaluate the subscript.
This method reduces an indexed reference to a literal array into the value
within the array, e.g. `["foo", "bar"][1]` becomes `"bar"`
Returns
-------
ExprNode
Node representing the result of the evaluation.
"""
slice_ = self.slice.value.get_folded_value()
value = self.value.get_folded_value()

if not isinstance(value, List):
raise UnfoldableNode("Subscript object is not a literal list")

elements = value.elements
if len(set([type(i) for i in elements])) > 1:
raise UnfoldableNode("List contains multiple node types")

if not isinstance(slice_, Int):
raise UnfoldableNode("invalid index type", slice_)

idx = slice_.value
if idx < 0 or idx >= len(elements):
raise UnfoldableNode("invalid index value")

return elements[idx]


class Index(VyperNode):
__slots__ = ("value",)
1 change: 0 additions & 1 deletion vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
@@ -31,7 +31,6 @@ class VyperNode:
@classmethod
def get_fields(cls: Any) -> set: ...
def get_folded_value(self) -> VyperNode: ...
def _try_fold(self) -> VyperNode: ...
def _set_folded_value(self, node: VyperNode) -> None: ...
@classmethod
def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ...
160 changes: 160 additions & 0 deletions vyper/semantics/analysis/constant_folding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from vyper import ast as vy_ast
from vyper.exceptions import InvalidLiteral, UndeclaredDefinition, UnfoldableNode
from vyper.semantics.analysis.common import VyperNodeVisitorBase
from vyper.semantics.namespace import get_namespace


class ConstantFolder(VyperNodeVisitorBase):
def visit(self, node):
for c in node.get_children():
try:
self.visit(c)
except UnfoldableNode:
# ignore bubbled up exceptions
pass

try:
for class_ in node.__class__.mro():
ast_type = class_.__name__

visitor_fn = getattr(self, f"visit_{ast_type}", None)
if visitor_fn:
folded_value = visitor_fn(node)
node._set_folded_value(folded_value)
return folded_value
else:
raise UnfoldableNode
except UnfoldableNode:
# ignore bubbled up exceptions
pass

def visit_Constant(self, node) -> vy_ast.ExprNode:
return node

def visit_Name(self, node) -> vy_ast.ExprNode:
namespace = get_namespace()
try:
ret = namespace[node]
except UndeclaredDefinition:
raise UnfoldableNode("unknown name", node)

if not isinstance(ret, vy_ast.VariableDecl) and not ret.is_constant:
raise UnfoldableNode("not a constant", node)

return ret.value.get_folded_value()

def visit_UnaryOp(self, node):
operand = node.operand.get_folded_value()

if isinstance(node.op, vy_ast.Not) and not isinstance(operand, vy_ast.NameConstant):
raise UnfoldableNode("not a boolean!", node.operand)
if isinstance(node.op, vy_ast.USub) and not isinstance(operand, vy_ast.Num):
raise UnfoldableNode("not a number!", node.operand)
if isinstance(node.op, vy_ast.Invert) and not isinstance(operand, vy_ast.Int):
raise UnfoldableNode("not an int!", node.operand)

value = node.op._op(operand.value)
return type(operand).from_node(node, value=value)

def visit_BinOp(self, node):
left, right = [i.get_folded_value() for i in (node.left, node.right)]
if type(left) is not type(right):
raise UnfoldableNode("invalid operation", node)
if not isinstance(left, vy_ast.Num):
raise UnfoldableNode("not a number!", node.left)

# this validation is performed to prevent the compiler from hanging
# on very large shifts and improve the error message for negative
# values.
if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)) and not (0 <= right.value <= 256):
raise InvalidLiteral("Shift bits must be between 0 and 256", node.right)

value = node.op._op(left.value, right.value)
return type(left).from_node(node, value=value)

def visit_BoolOp(self, node):
values = [v.get_folded_value() for v in node.values]

if any(not isinstance(v, vy_ast.NameConstant) for v in values):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")

values = [v.value for v in values]
value = node.op._op(values)
return vy_ast.NameConstant.from_node(node, value=value)

def visit_Compare(self, node):
left, right = [i.get_folded_value() for i in (node.left, node.right)]
if not isinstance(left, vy_ast.Constant):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")

# CMC 2022-08-04 we could probably remove these evaluation rules as they
# are taken care of in the IR optimizer now.
if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)):
if not isinstance(right, vy_ast.List):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")
if next((i for i in right.elements if not isinstance(i, vy_ast.Constant)), None):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")
if len(set([type(i) for i in right.elements])) > 1:
raise UnfoldableNode("List contains multiple literal types")
value = node.op._op(left.value, [i.value for i in right.elements])
return vy_ast.NameConstant.from_node(node, value=value)

if not isinstance(left, type(right)):
raise UnfoldableNode("Cannot compare different literal types")

# this is maybe just handled in the type checker.
if not isinstance(node.op, (vy_ast.Eq, vy_ast.NotEq)) and not isinstance(left, vy_ast.Num):
raise UnfoldableNode(
f"Invalid literal types for {node.op.description} comparison", node
)

value = node.op._op(left.value, right.value)
return vy_ast.NameConstant.from_node(node, value=value)

def visit_List(self, node) -> vy_ast.ExprNode:
elements = [e.get_folded_value() for e in node.elements]
return type(node).from_node(node, elements=elements)

def visit_Tuple(self, node) -> vy_ast.ExprNode:
elements = [e.get_folded_value() for e in node.elements]
return type(node).from_node(node, elements=elements)

def visit_Dict(self, node) -> vy_ast.ExprNode:
values = [v.get_folded_value() for v in node.values]
return type(node).from_node(node, values=values)

def visit_Call(self, node) -> vy_ast.ExprNode:
if not isinstance(node.func, vy_ast.Name):
raise UnfoldableNode("not a builtin", node)

namespace = get_namespace()

func_name = node.func.id
if func_name not in namespace:
raise UnfoldableNode("unknown", node)

typ = namespace[func_name]
# TODO: rename to vyper_type.try_fold_call_expr
if not hasattr(typ, "_try_fold"):
raise UnfoldableNode("unfoldable", node)
return typ._try_fold(node)

def visit_Subscript(self, node) -> vy_ast.ExprNode:
slice_ = node.slice.value.get_folded_value()
value = node.value.get_folded_value()

if not isinstance(value, vy_ast.List):
raise UnfoldableNode("Subscript object is not a literal list")

elements = value.elements
if len(set([type(i) for i in elements])) > 1:
raise UnfoldableNode("List contains multiple node types")

if not isinstance(slice_, vy_ast.Int):
raise UnfoldableNode("invalid index type", slice_)

idx = slice_.value
if idx < 0 or idx >= len(elements):
raise UnfoldableNode("invalid index value")

return elements[idx]
12 changes: 2 additions & 10 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
@@ -25,12 +25,7 @@
from vyper.semantics.analysis.common import VyperNodeVisitorBase
from vyper.semantics.analysis.import_graph import ImportGraph
from vyper.semantics.analysis.local import ExprVisitor, validate_functions
from vyper.semantics.analysis.pre_typecheck import pre_typecheck
from vyper.semantics.analysis.utils import (
check_modifiability,
get_exact_type_from_node,
validate_expected_type,
)
from vyper.semantics.analysis.utils import check_modifiability, get_exact_type_from_node
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace
from vyper.semantics.types import EventT, FlagT, InterfaceT, StructT
@@ -55,8 +50,6 @@ def validate_semantics_r(
"""
validate_literal_nodes(module_ast)

pre_typecheck(module_ast)

# validate semantics and annotate AST with type/semantics information
namespace = get_namespace()

@@ -315,12 +308,11 @@ def _validate_self_namespace():
if node.is_constant:
assert node.value is not None # checked in VariableDecl.validate()

ExprVisitor().visit(node.value, type_)
ExprVisitor().visit(node.value, type_) # performs validate_expected_type

if not check_modifiability(node.value, Modifiability.CONSTANT):
raise StateAccessViolation("Value must be a literal", node.value)

validate_expected_type(node.value, type_)
_validate_self_namespace()

return _finalize()
94 changes: 0 additions & 94 deletions vyper/semantics/analysis/pre_typecheck.py

This file was deleted.

2 changes: 2 additions & 0 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
@@ -650,6 +650,8 @@ def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) ->
return all(check_modifiability(v, modifiability) for v in args[0].values)

call_type = get_exact_type_from_node(node.func)

# builtins
call_type_modifiability = getattr(call_type, "_modifiability", Modifiability.MODIFIABLE)
return call_type_modifiability >= modifiability

16 changes: 15 additions & 1 deletion vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,12 @@
from vyper import ast as vy_ast
from vyper.abi_types import ABI_Address, ABIType
from vyper.ast.validation import validate_call_args
from vyper.exceptions import InterfaceViolation, NamespaceCollision, StructureException
from vyper.exceptions import (
InterfaceViolation,
NamespaceCollision,
StructureException,
UnfoldableNode,
)
from vyper.semantics.analysis.base import VarInfo
from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids
from vyper.semantics.namespace import get_namespace
@@ -53,6 +58,15 @@ def abi_type(self) -> ABIType:
def __repr__(self):
return f"interface {self._id}"

def _try_fold(self, node):
if len(node.args) != 1:
raise UnfoldableNode("wrong number of args", node.args)
args = [arg.get_folded_value() for arg in node.args]
if not isinstance(args[0], vy_ast.Hex):
raise UnfoldableNode("not an address", node.args[0])

return node

# when using the type itself (not an instance) in the call position
def _ctor_call_return(self, node: vy_ast.Call) -> "InterfaceT":
self._ctor_arg_types(node)
11 changes: 11 additions & 0 deletions vyper/semantics/types/user.py
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@
InvalidAttribute,
NamespaceCollision,
StructureException,
UnfoldableNode,
UnknownAttribute,
VariableDeclarationException,
)
@@ -357,6 +358,16 @@ def from_StructDef(cls, base_node: vy_ast.StructDef) -> "StructT":
def __repr__(self):
return f"{self._id} declaration object"

def _try_fold(self, node):
if len(node.args) != 1:
raise UnfoldableNode("wrong number of args", node.args)
args = [arg.get_folded_value() for arg in node.args]
if not isinstance(args[0], vy_ast.Dict):
raise UnfoldableNode("not a dict")

# it can't be reduced, but this lets upstream code know it's constant
return node

@property
def size_in_bytes(self):
return sum(i.size_in_bytes for i in self.member_types.values())