diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/testing.py b/userbenchmark/dynamo/dynamobench/_dynamo/testing.py index 9281c7c7e2..406093119c 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/testing.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/testing.py @@ -382,6 +382,11 @@ def rand_strided( _T = TypeVar("_T") +def check_dynamic_shape_capture() -> bool: + # This also mirrors config from `test/dynamo/test_dynamic_shapes.py:make_dynamic_cls` + return not config.assume_static_by_default + + def _make_fn_with_patches(fn: Callable[..., _T], *patches: Any) -> Callable[..., _T]: @functools.wraps(fn) def _fn(*args: Any, **kwargs: Any) -> _T: