diff --git a/src/accelerate/commands/test.py b/src/accelerate/commands/test.py index 41da7559679..69045371496 100644 --- a/src/accelerate/commands/test.py +++ b/src/accelerate/commands/test.py @@ -45,10 +45,12 @@ def test_command_parser(subparsers=None): def test_command(args): script_name = os.path.sep.join(__file__.split(os.path.sep)[:-2] + ["test_utils", "scripts", "test_script.py"]) - test_args = f""" - --config_file={args.config_file} {script_name} - """.split() - cmd = ["accelerate-launch"] + test_args + if args.config_file is None: + test_args = script_name + else: + test_args = f"--config_file={args.config_file} {script_name}" + + cmd = ["accelerate-launch"] + test_args.split() result = execute_subprocess_async(cmd, env=os.environ.copy()) if result.returncode == 0: print("Test is a success! You are ready for your distributed training!")