diff --git a/optimum_benchmark/backends/pytorch/config.py b/optimum_benchmark/backends/pytorch/config.py index ec48f639..61f9dfc0 100644 --- a/optimum_benchmark/backends/pytorch/config.py +++ b/optimum_benchmark/backends/pytorch/config.py @@ -5,7 +5,6 @@ from ...system_utils import is_rocm_system from ..config import BackendConfig -DEVICE_MAPS = ["auto", "sequential"] AMP_DTYPES = ["bfloat16", "float16"] TORCH_DTYPES = ["bfloat16", "float16", "float32", "auto"] @@ -60,9 +59,6 @@ def __post_init__(self): "Please remove it from the `model_kwargs` and set it in the backend config directly." ) - if self.device_map is not None and self.device_map not in DEVICE_MAPS: - raise ValueError(f"`device_map` must be one of {DEVICE_MAPS}. Got {self.device_map} instead.") - if self.torch_dtype is not None and self.torch_dtype not in TORCH_DTYPES: raise ValueError(f"`torch_dtype` must be one of {TORCH_DTYPES}. Got {self.torch_dtype} instead.")