Skip to content

Commit

Permalink
fix[lang]: builtin type comparisons (vyperlang#3956)
Browse files Browse the repository at this point in the history
prior to this commit, due to re-loading and re-analysing the builtin
modules, the types defined in those modules do not compare when they are
imported in different files. this commit globally(!) caches the builtin
modules, ensuring builtin types compare correctly no matter how they are
imported. the global scope should be safe, since builtins are always
stable across compilations.
  • Loading branch information
charles-cooper authored and electriclilies committed Apr 27, 2024
1 parent 833a55d commit 419ee8c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 5 deletions.
29 changes: 29 additions & 0 deletions tests/functional/codegen/modules/test_interface_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,32 @@ def test_foo(s: ifaces.IFoo) -> bool:
c = get_contract(contract, input_bundle=input_bundle)

assert c.test_foo(foo.address) is True


def test_import_interface_types_stability(make_input_bundle, get_contract):
lib1 = """
from ethereum.ercs import IERC20
"""
lib2 = """
from ethereum.ercs import IERC20
"""

main = """
import lib1
import lib2
from ethereum.ercs import IERC20
@external
def foo() -> bool:
# check that this typechecks both directions
a: lib1.IERC20 = IERC20(msg.sender)
b: lib2.IERC20 = IERC20(msg.sender)
# return the equality so we can sanity check it
return a == b
"""
input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2})
c = get_contract(main, input_bundle=input_bundle)

assert c.foo() is True
20 changes: 16 additions & 4 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
FileInput,
FilesystemInputBundle,
InputBundle,
PathLike,
)
from vyper.evm.opcodes import version_check
from vyper.exceptions import (
Expand Down Expand Up @@ -398,7 +399,7 @@ def _ast_from_file(self, file: FileInput) -> vy_ast.Module:
# two ASTs produced from the same source
ast_of = self.input_bundle._cache._ast_of
if file.source_id not in ast_of:
ast_of[file.source_id] = _parse_and_fold_ast(file)
ast_of[file.source_id] = _parse_ast(file)

return ast_of[file.source_id]

Expand Down Expand Up @@ -870,7 +871,7 @@ def _load_import_helper(
raise ModuleNotFound(module_str, hint=hint) from err


def _parse_and_fold_ast(file: FileInput) -> vy_ast.Module:
def _parse_ast(file: FileInput) -> vy_ast.Module:
module_path = file.resolved_path # for error messages
try:
# try to get a relative path, to simplify the error message
Expand Down Expand Up @@ -910,6 +911,9 @@ def _is_builtin(module_str):
return any(module_str.startswith(prefix) for prefix in BUILTIN_PREFIXES)


_builtins_cache: dict[PathLike, tuple[CompilerInput, ModuleT]] = {}


def _load_builtin_import(level: int, module_str: str) -> tuple[CompilerInput, InterfaceT]:
if not _is_builtin(module_str):
raise ModuleNotFound(module_str)
Expand All @@ -933,6 +937,13 @@ def _load_builtin_import(level: int, module_str: str) -> tuple[CompilerInput, In

path = _import_to_path(level, remapped_module).with_suffix(".vyi")

# builtins are globally the same, so we can safely cache them
# (it is also *correct* to cache them, so that types defined in builtins
# compare correctly using pointer-equality.)
if path in _builtins_cache:
file, module_t = _builtins_cache[path]
return file, module_t.interface

try:
file = input_bundle.load_file(path)
assert isinstance(file, FileInput) # mypy hint
Expand All @@ -946,9 +957,10 @@ def _load_builtin_import(level: int, module_str: str) -> tuple[CompilerInput, In
hint = f"try renaming `{module_prefix}` to `I{module_prefix}`"
raise ModuleNotFound(module_str, hint=hint) from e

# TODO: it might be good to cache this computation
interface_ast = _parse_and_fold_ast(file)
interface_ast = _parse_ast(file)

with override_global_namespace(Namespace()):
module_t = _analyze_module_r(interface_ast, input_bundle, ImportGraph(), is_interface=True)

_builtins_cache[path] = file, module_t
return file, module_t.interface
2 changes: 1 addition & 1 deletion vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _is_function_implemented(fn_name, fn_type):
continue

if not _is_function_implemented(name, type_):
unimplemented.append(name)
unimplemented.append(type_._pp_signature)

if len(unimplemented) > 0:
# TODO: improve the error message for cases where the
Expand Down

0 comments on commit 419ee8c

Please sign in to comment.