diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index 5ad00a2942..7ad8ff20e1 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -279,8 +279,11 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context # Because of how the code generator works Scalars can not be return values. # TODO: Remove this limitation as the CompiledSDFG contains logic for that. - if isinstance(desc, dt.Scalar) and name.startswith("__return") and not desc.transient: - raise InvalidSDFGError(f'Can not use scalar "{name}" as return value.', sdfg, None) + if (sdfg.parent is None and isinstance(desc, dt.Scalar) and name.startswith("__return") + and not desc.transient): + raise InvalidSDFGError( + f'Cannot use scalar data descriptor ("{name}") as return value of a top-level function.', sdfg, + None) # Validate array names if name is not None and not dtypes.validate_name(name): @@ -332,7 +335,6 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context symbols[str(sym)] = sym.dtype validate_control_flow_region(sdfg, sdfg, initialized_transients, symbols, references, **context) - except InvalidSDFGError as ex: # If the SDFG is invalid, save it fpath = os.path.join('_dacegraphs', 'invalid.sdfgz') @@ -340,6 +342,7 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context ex.path = fpath raise + def _accessible(sdfg: 'dace.sdfg.SDFG', container: str, context: Dict[str, bool]): """ Helper function that returns False if a data container cannot be accessed in the current SDFG context. diff --git a/tests/python_frontend/return_value_test.py b/tests/python_frontend/return_value_test.py index 93870c41ce..4a845bea0b 100644 --- a/tests/python_frontend/return_value_test.py +++ b/tests/python_frontend/return_value_test.py @@ -12,6 +12,19 @@ def return_scalar(): assert return_scalar() == 5 +def test_return_scalar_in_nested_function(): + + @dace.program + def nested_function() -> dace.int32: + return 5 + + @dace.program + def return_scalar(): + return nested_function() + + assert return_scalar() == 5 + + def test_return_array(): @dace.program @@ -91,6 +104,7 @@ def return_void(a: dace.float64[20]): if __name__ == '__main__': test_return_scalar() + test_return_scalar_in_nested_function() test_return_array() test_return_tuple() test_return_array_tuple()