diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index cfe2dc47..9ba67054 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -167,7 +167,7 @@ def get_train_args(override_sys_args: Optional[List[str]] = None) -> TrainConfig override_sys_args: The command line arguments. If None, sys.argv[1:] is used. This is mainly useful for testing. """ - args = override_sys_args or sys.argv[1:] + args = sys.argv[1:] if override_sys_args is None else override_sys_args return simple_parsing.parse( config_class=TrainConfig, diff --git a/ultravox/training/config_base_test.py b/ultravox/training/config_base_test.py new file mode 100644 index 00000000..d5481574 --- /dev/null +++ b/ultravox/training/config_base_test.py @@ -0,0 +1,7 @@ +from ultravox.training import config_base + + +def test_can_create_train_config(): + # override args to [], otherwise pytest arguments will be used + args = config_base.get_train_args([]) + assert isinstance(args, config_base.TrainConfig)