Skip to content

Commit

Permalink
Refine condition for prohibiting scalar return values (#1838)
Browse files Browse the repository at this point in the history
The check disallowed valid cases where return values that are not
returned to the Python frontend were disallowed. Scalar return values in
nested functions take the form of a reference, which is not supported
with our current Python bindings.
  • Loading branch information
tbennun authored Dec 30, 2024
1 parent dbc7747 commit 1b25eb7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
9 changes: 6 additions & 3 deletions dace/sdfg/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,11 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context

# Because of how the code generator works Scalars can not be return values.
# TODO: Remove this limitation as the CompiledSDFG contains logic for that.
if isinstance(desc, dt.Scalar) and name.startswith("__return") and not desc.transient:
raise InvalidSDFGError(f'Can not use scalar "{name}" as return value.', sdfg, None)
if (sdfg.parent is None and isinstance(desc, dt.Scalar) and name.startswith("__return")
and not desc.transient):
raise InvalidSDFGError(
f'Cannot use scalar data descriptor ("{name}") as return value of a top-level function.', sdfg,
None)

# Validate array names
if name is not None and not dtypes.validate_name(name):
Expand Down Expand Up @@ -332,14 +335,14 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context
symbols[str(sym)] = sym.dtype
validate_control_flow_region(sdfg, sdfg, initialized_transients, symbols, references, **context)


except InvalidSDFGError as ex:
# If the SDFG is invalid, save it
fpath = os.path.join('_dacegraphs', 'invalid.sdfgz')
sdfg.save(fpath, exception=ex, compress=True)
ex.path = fpath
raise


def _accessible(sdfg: 'dace.sdfg.SDFG', container: str, context: Dict[str, bool]):
"""
Helper function that returns False if a data container cannot be accessed in the current SDFG context.
Expand Down
14 changes: 14 additions & 0 deletions tests/python_frontend/return_value_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,19 @@ def return_scalar():
assert return_scalar() == 5


def test_return_scalar_in_nested_function():

@dace.program
def nested_function() -> dace.int32:
return 5

@dace.program
def return_scalar():
return nested_function()

assert return_scalar() == 5


def test_return_array():

@dace.program
Expand Down Expand Up @@ -91,6 +104,7 @@ def return_void(a: dace.float64[20]):

if __name__ == '__main__':
test_return_scalar()
test_return_scalar_in_nested_function()
test_return_array()
test_return_tuple()
test_return_array_tuple()
Expand Down

0 comments on commit 1b25eb7

Please sign in to comment.