
Commit

fix for issue #11: some codepaths were invoking `ScheduleImplMixin.save_schedule` from non-`rank_zero_only` guarded contexts

Explicitly guard `save_schedule` and `gen_ft_schedule` themselves so that similar bugs surface during development if those functions are accessed directly from non-`rank_zero_only` contexts in the future. Includes associated test enhancements.
speediedan committed Dec 20, 2023
1 parent 87e6cf7 commit ec348db
Showing 4 changed files with 50 additions and 15 deletions.
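For context, the sketch below illustrates the guard semantics the fix relies on. It is a minimal, self-contained stand-in for Lightning's `rank_zero_only` utility, not the actual implementation; `_GLOBAL_RANK` and `gen_schedule` are illustrative names only. The point it demonstrates: a rank-zero-guarded function executes and returns a value only on global rank zero and returns None on every other rank, which is why invoking such functions from unguarded codepaths can silently misbehave, and why the now-guarded `gen_ft_schedule` advertises an `Optional[os.PathLike]` return type in the diff below.

# Minimal sketch of rank-zero-only guard semantics (simplified stand-in; not the
# lightning_utilities implementation: _GLOBAL_RANK and gen_schedule are illustrative only).
import functools
from typing import Callable, Optional, TypeVar

T = TypeVar("T")
_GLOBAL_RANK = 0  # in a real distributed run this would come from the launcher/environment

def rank_zero_only(fn: Callable[..., T]) -> Callable[..., Optional[T]]:
    """Execute ``fn`` only on global rank zero; return ``None`` on every other rank."""
    @functools.wraps(fn)
    def wrapped(*args, **kwargs) -> Optional[T]:
        if _GLOBAL_RANK == 0:
            return fn(*args, **kwargs)
        return None  # silently skipped off rank zero, hence the Optional return annotations
    return wrapped

@rank_zero_only
def gen_schedule(dump_loc: str) -> str:
    # stand-in for ScheduleImplMixin.gen_ft_schedule: write a schedule, return its path
    return f"{dump_loc}/model_ft_schedule.yaml"

# Rank zero receives a path; all other ranks receive None, so callers must either run under
# an is_global_zero check (as fts.py does below) or handle the None case explicitly.
print(gen_schedule("/tmp/logs"))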
src/finetuning_scheduler/fts.py (2 changes: 1 addition & 1 deletion)
@@ -645,7 +645,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
if self.gen_ft_sched_only:
if trainer.is_global_zero:
assert trainer.log_dir is not None
- _ = self.gen_ft_schedule(pl_module, trainer.log_dir)
+ _ = ScheduleImplMixin.gen_ft_schedule(pl_module, trainer.log_dir)
log.info("Bypassing training, generating fine-tuning schedule for review and subsequent fine-tuning")
raise SystemExit(0)
if not self.epoch_transitions_only:
src/finetuning_scheduler/fts_supporters.py (7 changes: 5 additions & 2 deletions)
@@ -1357,10 +1357,12 @@ def gen_implicit_schedule(self, sched_dir: Union[str, os.PathLike]) -> None:
sched_dir: directory to which the generated schedule should be written. By default will be
``Trainer.log_dir``.
"""
- default_ft_schedule = self.gen_ft_schedule(self.pl_module, sched_dir)
+ default_ft_schedule = ScheduleImplMixin.gen_ft_schedule(self.pl_module, sched_dir)
assert default_ft_schedule is not None
rank_zero_info(f"Generated default fine-tuning schedule '{default_ft_schedule}' for iterative fine-tuning")
self.ft_schedule = self.load_yaml_schedule(default_ft_schedule)

+ @rank_zero_only
@staticmethod
def save_schedule(schedule_name: str, layer_config: Dict, dump_loc: Union[str, os.PathLike]) -> os.PathLike:
"""Save loaded or generated schedule to a directory to ensure reproducability.
@@ -1385,8 +1387,9 @@ def save_schedule(schedule_name: str, layer_config: Dict, dump_loc: Union[str, o
rank_zero_info(f"fine-tuning schedule dumped to {ft_schedule_yaml}.")
return ft_schedule_yaml

+ @rank_zero_only
@staticmethod
- def gen_ft_schedule(module: Module, dump_loc: Union[str, os.PathLike]) -> os.PathLike:
+ def gen_ft_schedule(module: Module, dump_loc: Union[str, os.PathLike]) -> Optional[os.PathLike]:
"""Generate the default fine-tuning schedule using a naive, 2-parameters per-level heuristic.
Args:
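The `gen_ft_schedule` docstring above refers to a naive, 2-parameters per-level heuristic. As a rough illustration only (an assumption about the general shape of such a heuristic, not FinetuningScheduler's exact implementation), a default schedule can be built by walking a module's trainable parameters from the output side back toward the input and grouping them two per fine-tuning phase:

# Illustrative sketch of a naive 2-parameters-per-level schedule heuristic; an assumption
# about the general approach, not the library's actual gen_ft_schedule implementation.
from typing import Dict, List
import torch.nn as nn

def naive_ft_schedule(module: nn.Module, params_per_level: int = 2) -> Dict[int, Dict[str, List[str]]]:
    """Group trainable parameter names, last-defined first, into phases of ``params_per_level``."""
    names = [n for n, p in module.named_parameters() if p.requires_grad]
    names.reverse()  # unfreeze the final (task-adjacent) parameters in the earliest phases
    return {
        level: {"params": names[i : i + params_per_level]}
        for level, i in enumerate(range(0, len(names), params_per_level))
    }

# Example: a tiny model yields phase 0 with the output layer's parameters, phase 1 with the
# hidden layer's; the real generator additionally dumps the mapping to a YAML schedule file.
model = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 2))
print(naive_ft_schedule(model))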
tests/test_finetuning_scheduler_callback.py (46 changes: 38 additions & 8 deletions)
@@ -11,6 +11,7 @@
# limitations under the License.
import os
import re
+ from tempfile import gettempdir
from collections import OrderedDict
from copy import deepcopy
from logging import DEBUG
@@ -50,6 +51,22 @@ def get_fts(trainer: "Trainer") -> Callback:
def nones(num_n) -> Tuple: # to help dedup config
return (None,) * num_n

+ DIST_TEST_SYMDIR = Path(gettempdir()) / "current_dist_test"
+
+ def manage_dist_test_symlink(src, dst=DIST_TEST_SYMDIR, overwrite=True):
+ """Creates or updates our symlink for use with distributed tests.
+ Args:
+ src: The source path.
+ dst: The destination path.
+ overwrite: Whether to overwrite an existing symlink.
+ """
+ if dst.exists() and not overwrite:
+ return
+ if dst.is_symlink() or dst.exists():
+ os.unlink(dst)
+ os.symlink(src, dst)
+ return dst

class AverageDataset(Dataset):
def __init__(self, dataset_len=300, sequence_len=100):
@@ -495,6 +512,14 @@ def ckpt_set(tmpdir_factory) -> Dict:
trainer.fit(model)
return {"best": trainer.checkpoint_callback.best_model_path, "kth": trainer.checkpoint_callback.kth_best_model_path}

+ def get_sched_fixture_tmpdir(tmpfactory_handle):
+ rank = getattr(rank_zero_only, "rank", 0)
+ if rank == 0:
+ tmpdir = tmpfactory_handle.getbasetemp()
+ _ = manage_dist_test_symlink(tmpdir)
+ else:
+ tmpdir = DIST_TEST_SYMDIR
+ return tmpdir, rank

@pytest.fixture(scope="function")
def boring_ft_schedule(tmpdir_factory) -> Tuple[Path, Dict]:
@@ -503,12 +528,14 @@ def boring_ft_schedule(tmpdir_factory) -> Tuple[Path, Dict]:
seed_everything(42)
callbacks = [FinetuningScheduler(gen_ft_sched_only=True)]
model = FinetuningSchedulerBoringModel()
- tmpdir = tmpdir_factory.getbasetemp()
+ tmpdir, rank = get_sched_fixture_tmpdir(tmpdir_factory)
trainer = Trainer(default_root_dir=tmpdir, callbacks=callbacks, devices=1)
- # unmod_schedule_file = tmpdir / "lightning_logs" / "version_0" / f"{model.__class__.__name__}_ft_schedule.yaml"]
- unmod_schedule_file = Path(trainer.log_dir) / f"{model.__class__.__name__}_ft_schedule.yaml"
- with pytest.raises(SystemExit):
- trainer.fit(model)
+ unmod_schedule_file = tmpdir / "lightning_logs" / "version_0" / f"{model.__class__.__name__}_ft_schedule.yaml"
+ # N.B. Though we run this fixture for each rank to avoid adding special logic to each distributed client test, we
+ # only generate a schedule on rank 0, linking to it on the other ranks.
+ if rank == 0:
+ with pytest.raises(SystemExit):
+ trainer.fit(model)
mod_sched_dict = get_fts(trainer).load_yaml_schedule(unmod_schedule_file)
reinit_optim_sched_dict = deepcopy(mod_sched_dict)
reinitlr_sched_dict = deepcopy(mod_sched_dict)
@@ -2562,17 +2589,20 @@ def on_validation_epoch_end(self):


@RunIf(standalone=True, min_cuda_gpus=2)
- def test_fts_multi_ddp(tmpdir):
+ @pytest.mark.parametrize("explicit_mode", [True, False], ids=["explicit", "implicit"])
+ def test_fts_multi_ddp(tmpdir, boring_ft_schedule, explicit_mode):
"""Validate :class:`~finetuning_scheduler.FinetuningScheduler` functions properly in a supported 'ddp'
distributed context."""
seed_everything(42)
+ ft_schedule = boring_ft_schedule[1] if explicit_mode else None
+ expected_depth = 2 if explicit_mode else 3
model = FinetuningSchedulerBoringModel()
- callbacks = [FinetuningScheduler(), FTSEarlyStopping(monitor="val_loss", patience=1)]
+ callbacks = [FinetuningScheduler(ft_schedule=ft_schedule), FTSEarlyStopping(monitor="val_loss", patience=1)]
trainer = Trainer(default_root_dir=tmpdir, callbacks=callbacks, strategy="ddp", devices=2)
finetuningscheduler_callback = get_fts(trainer)
trainer.fit(model)
assert finetuningscheduler_callback.depth_remaining == 0
- assert finetuningscheduler_callback.curr_depth == 3
+ assert finetuningscheduler_callback.curr_depth == expected_depth
assert finetuningscheduler_callback.curr_depth == finetuningscheduler_callback.max_depth


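As a quick illustration of how the new test helpers above cooperate across ranks (a hedged sketch: `resolve_schedule_path` is a hypothetical helper, not part of this commit), rank 0 generates schedules beneath its real pytest basetemp and points the shared `current_dist_test` symlink at it, while the remaining ranks reach the same files through the symlink:

# Hypothetical usage sketch (resolve_schedule_path is illustrative only); it mirrors the
# pattern above: rank 0 owns the real basetemp, other ranks follow the shared symlink.
from pathlib import Path
from tempfile import gettempdir

DIST_TEST_SYMDIR = Path(gettempdir()) / "current_dist_test"

def resolve_schedule_path(rank: int, rank_zero_basetemp: Path, model_cls_name: str) -> Path:
    base = rank_zero_basetemp if rank == 0 else DIST_TEST_SYMDIR
    return base / "lightning_logs" / "version_0" / f"{model_cls_name}_ft_schedule.yaml"

# e.g. rank 1 resolves the schedule that rank 0 generated, via the symlink:
print(resolve_schedule_path(1, Path("/tmp/pytest-of-user/pytest-0"), "FinetuningSchedulerBoringModel"))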
tests/test_fsdp.py (10 changes: 6 additions & 4 deletions)
@@ -40,6 +40,7 @@
get_fts,
nones,
TestFinetuningScheduler,
+ get_sched_fixture_tmpdir,
)

_distributed_available = torch.distributed.is_available()
@@ -102,11 +103,12 @@ def fsdp_ft_schedules(tmpdir_factory) -> Tuple[Path, Dict]:
seed_everything(42)
callbacks = [FinetuningScheduler(gen_ft_sched_only=True), FTSCheckpoint(monitor="val_loss")]
model = FinetuningSchedulerBoringModel()
- tmpdir = tmpdir_factory.getbasetemp()
+ tmpdir, rank = get_sched_fixture_tmpdir(tmpdir_factory)
trainer = Trainer(default_root_dir=tmpdir, callbacks=callbacks, devices=1)
- unmod_schedule_file = Path(trainer.log_dir) / f"{model.__class__.__name__}_ft_schedule.yaml"
- with pytest.raises(SystemExit):
- trainer.fit(model)
+ unmod_schedule_file = tmpdir / "lightning_logs" / "version_0" / f"{model.__class__.__name__}_ft_schedule.yaml"
+ if rank == 0:
+ with pytest.raises(SystemExit):
+ trainer.fit(model)
mod_sched_dict = get_fts(trainer).load_yaml_schedule(unmod_schedule_file)
mod_sched_dict[0]["params"].extend(mod_sched_dict.pop(1)["params"])
mod_sched_dict[0]["max_transition_epoch"] = 3
