From d3689afeb779b91ca898f43a5b3f14e6518ef6e9 Mon Sep 17 00:00:00 2001
From: Ravi Rahman
Date: Wed, 11 May 2022 07:57:55 -0700
Subject: [PATCH] Fix tests

---
 composer/trainer/trainer_hparams.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/composer/trainer/trainer_hparams.py b/composer/trainer/trainer_hparams.py
index 2f437ff8ad..3cb4683f9c 100755
--- a/composer/trainer/trainer_hparams.py
+++ b/composer/trainer/trainer_hparams.py
@@ -55,6 +55,8 @@
 # They exist purely for pyright and should never need
 __all__ = ["TrainerHparams", "FitHparams", "EvalHparams", "ExperimentHparams"]
 
+Scheduler = Union[ComposerScheduler, PyTorchScheduler]
+
 optimizer_registry = {
     "adam": AdamHparams,
     "adamw": AdamWHparams,
@@ -350,9 +352,9 @@ class TrainerHparams(hp.Hparams):
         doc="Ratio by which to scale the training duration and learning rate schedules.",
         default=1.0,
     )
-    step_schedulers_every_batch: bool = hp.optional(
+    step_schedulers_every_batch: Optional[bool] = hp.optional(
         doc="Whether schedulers will update after every optimizer step (True), or every epoch (False).",
-        default=True,
+        default=None,
     )
 
     # Evaluation
@@ -707,8 +709,7 @@ class FitKwargs(TypedDict):
     duration: Optional[Union[int, str, Time[int]]]
 
     # Schedulers
-    schedulers: Optional[Union[ComposerScheduler, PyTorchScheduler, Sequence[Union[ComposerScheduler,
-                                                                                   PyTorchScheduler]]]]
+    schedulers: Optional[Union[Scheduler, Sequence[Scheduler]]]
     scale_schedule_ratio: float
     step_schedulers_every_batch: Optional[bool]
 
@@ -777,9 +778,9 @@ class FitHparams(hp.Hparams):
         doc="Ratio by which to scale the training duration and learning rate schedules.",
         default=1.0,
     )
-    step_schedulers_every_batch: bool = hp.optional(
+    step_schedulers_every_batch: Optional[bool] = hp.optional(
        doc="Whether schedulers will update after every optimizer step (True), or every epoch (False).",
-        default=True,
+        default=None,
    )
 
     # Evaluation
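
Note: the hunks above make `step_schedulers_every_batch` tri-state (True / False / None)
and introduce the `Scheduler` alias for `Union[ComposerScheduler, PyTorchScheduler]`.
As a rough illustration of what a None default enables, the sketch below shows one
way a trainer could resolve the flag. The helper name `resolve_step_frequency` and
the exact fallback rule are assumptions for illustration, not code from this patch.

    from typing import Optional, Sequence, Union

    # Hypothetical stand-ins for the real scheduler types, which live in
    # composer.optim and torch.optim.lr_scheduler.
    class ComposerScheduler: ...
    class PyTorchScheduler: ...

    Scheduler = Union[ComposerScheduler, PyTorchScheduler]

    def resolve_step_frequency(flag: Optional[bool], schedulers: Sequence[Scheduler]) -> bool:
        # An explicit True/False from the user always wins.
        if flag is not None:
            return flag
        # Assumed fallback for None: step per batch only when every scheduler
        # is a ComposerScheduler, since native PyTorch schedulers are
        # conventionally stepped once per epoch.
        return all(isinstance(s, ComposerScheduler) for s in schedulers)

With a hard `default=True`, hparams-built trainers could never express "let the
trainer decide"; the None default defers that choice until the schedulers are known.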