Print help on run_mosaic_trainer.py, cleaned up verbosity. (#170)
1. Moved `datadir` onto the trainer hparams to avoid an UnusedArgumentWarning.
2. Set the default loglevel to WARNING.
3. Cleaned up warnings so they no longer print a useless second line showing the source code of the `warnings.warn` call.
4. Made the launch script quiet by default, and added a `-v` option for verbose output.
5. If `run_mosaic_trainer.py` is run without arguments, the help text is printed.

Closes #88, #128
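
For quick reference, a minimal sketch of the API change in item 1, mirroring the README diff below; the `resnet50` preset and the dataset path are placeholders:

```python
from composer import trainer, algorithms, Trainer

# Load a preset and attach algorithms as before.
trainer_hparams = trainer.load("resnet50")
trainer_hparams.algorithms = algorithms.load_multiple("squeeze_excite", "scale_schedule")

# New in this commit: datadir is a plain field on the trainer hparams,
# replacing the removed set_datadir() helper.
trainer_hparams.datadir = 'your/dataset/path/'

learner = Trainer.create_from_hparams(hparams=trainer_hparams)
learner.fit()
```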
ravi-mosaicml authored Dec 17, 2021
1 parent 0671e7a commit 722edf0
Showing 7 changed files with 55 additions and 49 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -49,7 +49,7 @@ from composer import trainer, algorithms, Trainer

trainer_hparams = trainer.load("resnet50")
trainer_hparams.algorithms = algorithms.load_multiple("squeeze_excite", "scale_schedule")
trainer_hparams.set_datadir('your/dataset/path/')
trainer_hparams.datadir = 'your/dataset/path/'

learner = Trainer.create_from_hparams(hparams=trainer_hparams)
learner.fit()
48 changes: 30 additions & 18 deletions composer/cli/launcher.py
@@ -1,6 +1,7 @@
# Copyright 2021 MosaicML. All Rights Reserved.

import datetime
import logging
import os
import signal
import socket
@@ -14,6 +15,8 @@

CLEANUP_TIMEOUT = datetime.timedelta(seconds=30)

log = logging.getLogger(__name__)


def get_parser():
parser = ArgumentParser(description="Utility for launching distributed machine learning jobs.")
@@ -56,6 +59,7 @@ def get_parser():
"--module_mode",
action="store_true",
help="If set, run the training script as a module instead of as a script.")
parser.add_argument("-v", "--verbose", action="store_true", help="If set, print verbose messages")
parser.add_argument("training_script",
type=str,
help="The path to the training script used to initialize a single training "
@@ -90,7 +94,7 @@ def parse_args():
def launch_processes(nproc: int, world_size: int, base_rank: int, master_addr: str, master_port: Optional[int],
module_mode: bool, run_directory: Optional[str], training_script: str,
training_script_args: List[Any]) -> Set[subprocess.Popen]:
print(f"Starting DDP on local node for global_rank({base_rank}-{base_rank+nproc-1})")
log.info("Starting DDP on local node for global_rank(%s-%s)", base_rank, base_rank + nproc - 1)
processes = []

if run_directory is None:
@@ -104,7 +108,7 @@ def launch_processes(nproc: int, world_size: int, base_rank: int, master_addr: str,
"This may lead to race conditions when launching multiple training processes simultaneously. "
"To eliminate this race condition, explicitely specify a port with --master_port PORT_NUMBER")
master_port = get_free_tcp_port()
print(f"DDP Store: tcp://{master_addr}:{master_port}")
log.info("DDP Store: tcp://%s:%s", master_addr, master_port)

for local_rank in range(nproc):
global_rank = base_rank + local_rank
@@ -122,7 +126,7 @@ def launch_processes(nproc: int, world_size: int, base_rank: int, master_addr: str,
current_env["MASTER_PORT"] = str(master_port)
current_env["RUN_DIRECTORY"] = run_directory

print(f"Launching process for local_rank({local_rank}), global_rank({global_rank})")
log.info("Launching process for local_rank(%s), global_rank(%s)", local_rank, global_rank)

if local_rank == 0:
process = subprocess.Popen(cmd, env=current_env, text=True)
@@ -150,12 +154,12 @@ def monitor_processes(processes: Set[subprocess.Popen]):
# return code of -9 implies sigkill, presumably from cleanup_processes()
if process.returncode not in (0, -9):
if process.stdout is None:
output = ""
output = None
else:
output = process.stdout.read()

if process.stderr is None:
stderr = ""
stderr = None
else:
stderr = process.stderr.read()
exc = subprocess.CalledProcessError(
@@ -164,15 +168,20 @@ def monitor_processes(processes: Set[subprocess.Popen]):
output=output,
stderr=stderr,
)
error_msg = [
"Error in subprocess",
"----------Subprocess STDOUT----------",
exc.output,
"----------Subprocess STDERR----------",
exc.stderr,
]
error_msg = [f"Process {process.pid} exited with code {process.returncode}"]
if output is not None:
error_msg.extend([
"----------Begin subprocess STDOUT----------",
output,
"----------End subprocess STDOUT----------",
])
if stderr is not None:
error_msg.extend([
"----------Begin subprocess STDERR----------",
exc.stderr,
"----------End subprocess STDERR----------",
])
print("\n".join(error_msg))
print(exc)
sys.exit(process.returncode)
else:
# exited cleanly
@@ -185,14 +194,14 @@ def cleanup_processes(processes: Set[subprocess.Popen]):
for process in processes:
process.poll()
if process.returncode is None:
print(f"Killing subprocess {process.pid} with SIGTERM")
log.info("Killing subprocess %s with SIGTERM", process.pid)
try:
os.killpg(process.pid, signal.SIGTERM)
except ProcessLookupError:
pass

current_time = datetime.datetime.now()
print(f"Waiting {CLEANUP_TIMEOUT.seconds} seconds for processes to terminate...")
print(f"Waiting up to {CLEANUP_TIMEOUT.seconds} seconds for all training processes to terminate...")
while datetime.datetime.now() - current_time < CLEANUP_TIMEOUT:
for process in processes:
process.poll()
@@ -203,7 +212,7 @@ def cleanup_processes(processes: Set[subprocess.Popen]):
for process in processes:
process.poll()
if process.returncode is None:
print(f"Failed to kill subprocess {process.pid} with SIGTERM; using SIGKILL instead")
log.warn("Failed to kill subprocess %s with SIGTERM; using SIGKILL instead", process.pid)
try:
os.killpg(process.pid, signal.SIGKILL)
except ProcessLookupError:
@@ -214,7 +223,7 @@ def aggregate_process_returncode(processes: Set[subprocess.Popen]) -> int:
for process in processes:
process.poll()
if process.returncode is None:
print(f"Subprocess {process.pid} has still not exited; return exit code 1.")
log.warn("Subprocess %s has still not exited; return exit code 1.", process.pid)
return 1
if process.returncode != 0:
return process.returncode
@@ -225,6 +234,9 @@ def aggregate_process_returncode(processes: Set[subprocess.Popen]) -> int:
def main():
args = parse_args()

logging.basicConfig()
log.setLevel(logging.INFO if args.verbose else logging.WARN)

processes = launch_processes(nproc=args.nproc,
world_size=args.world_size,
base_rank=args.base_rank,
@@ -238,7 +250,7 @@ def main():
try:
monitor_processes(processes)
except KeyboardInterrupt:
print("Caught Ctrl+C; killing processes")
print("Caught Ctrl+C; killing training processes")
raise
finally:
cleanup_processes(processes)
6 changes: 6 additions & 0 deletions composer/trainer/trainer.py
@@ -321,6 +321,8 @@ def create_from_hparams(cls, hparams: TrainerHparams) -> Trainer:
"""

hparams.validate()
import composer
logging.getLogger(composer.__name__).setLevel(hparams.log_level)

# devices and systems
device = hparams.device.initialize_object()
@@ -337,6 +339,10 @@ def create_from_hparams(cls, hparams: TrainerHparams) -> Trainer:
dict_config = hparams.to_dict()
log_destinations = [x.initialize_object(config=dict_config) for x in hparams.loggers]

if hparams.datadir is not None:
hparams.train_dataset.datadir = hparams.datadir
hparams.val_dataset.datadir = hparams.datadir

train_device_batch_size = hparams.train_batch_size // ddp.get_world_size()
if hparams.train_dataset.shuffle and hparams.train_subset_num_batches:
warnings.warn(
15 changes: 5 additions & 10 deletions composer/trainer/trainer_hparams.py
@@ -195,7 +195,7 @@ class TrainerHparams(hp.Hparams):
default=False)

compute_training_metrics: bool = hp.optional(doc="Log validation metrics on training data", default=False)
log_level: str = hp.optional(doc="Python loglevel to use composer", default="INFO")
log_level: str = hp.optional(doc="Python loglevel to use composer", default="WARNING")
datadir: Optional[str] = hp.optional(doc=textwrap.dedent("""
Datadir to apply for both the training and validation datasets. If specified,
it will override train_dataset.datadir and val_dataset.datadir"""),
default=None)

def validate(self):
super().validate()
@@ -226,15 +230,6 @@ def initialize_object(self) -> Trainer:
from composer.trainer.trainer import Trainer
return Trainer.create_from_hparams(hparams=self)

def set_datadir(self, datadir: str) -> None:
"""Override the ``datadir`` property in the :attr:`train_dataset` and :attr:`val_dataset`.
Args:
datadir (str): The datadir
"""
self.train_dataset.datadir = datadir
self.val_dataset.datadir = datadir

@classmethod
def load(cls, model: str) -> TrainerHparams:
model_hparams_file = os.path.join(
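
For existing callers of the removed `set_datadir()` helper, a hedged before/after sketch of the migration (the YAML path is borrowed from the docs example below):

```python
from composer.trainer import Trainer, TrainerHparams

hparams = TrainerHparams.create('composer/yamls/models/classify_mnist_cpu.yaml')

# Before this commit:
#   hparams.set_datadir("~/datasets")
# After this commit, assign the new top-level field; Trainer.create_from_hparams()
# copies it onto both hparams.train_dataset.datadir and hparams.val_dataset.datadir.
hparams.datadir = "~/datasets"

trainer = Trainer.create_from_hparams(hparams)
trainer.fit()
```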
4 changes: 2 additions & 2 deletions docs/source/getting_started/using_composer.rst
@@ -73,7 +73,7 @@ Here are several ways to use the trainer:
# edit other properties in the hparams object
hparams.precision = Precision.FP32
hparams.grad_accum = 2
hparams.set_datadir("~/datasets")
hparams.datadir = "~/datasets"
trainer = Trainer.create_from_hparams(hparams)
trainer.fit()
@@ -97,7 +97,7 @@ Here are several ways to use the trainer:
from composer.trainer import TrainerHparams, Trainer
hparams = TrainerHparams.create('composer/yamls/models/classify_mnist_cpu.yaml')
hparams.set_datadir("~/datasets")
hparams.datadir = "~/datasets"
trainer = Trainer.create_from_hparams(hparams)
trainer.fit()
2 changes: 1 addition & 1 deletion examples/composer.ipynb
@@ -46,7 +46,7 @@
"trainer_hparams.algorithms = algorithms.load_multiple(\n",
" \"blurpool\",\n",
" \"scale_schedule\")\n",
"trainer_hparams.set_datadir(\"~/datasets\")"
"trainer_hparams.datadir = \"~/datasets\""
]
},
{
27 changes: 10 additions & 17 deletions examples/run_mosaic_trainer.py
@@ -12,33 +12,26 @@
--algorithms label_smoothing --alpha 0.1
--datadir ~/datasets
"""
import argparse
import logging
import sys
import warnings
from typing import Type

import composer
from composer.trainer.trainer import Trainer
from composer.trainer.trainer_hparams import TrainerHparams

logger = logging.getLogger(__name__)

def warning_on_one_line(message: str, category: Type[Warning], filename: str, lineno: int, file=None, line=None):
# From https://stackoverflow.com/questions/26430861/make-pythons-warnings-warn-not-mention-itself
return f'{category.__name__}: {message} (source: {filename}:{lineno})\n'


def main() -> None:
logging.basicConfig()
logging.captureWarnings(True)
warnings.formatwarning = warning_on_one_line

parser = argparse.ArgumentParser(parents=[TrainerHparams.get_argparse(cli_args=True)])
parser.add_argument(
'--datadir',
default=None,
help='set the datadir for both train and eval datasets',
)
if len(sys.argv) == 1:
sys.argv = [sys.argv[0], "--help"]

args, _ = parser.parse_known_args()
hparams = TrainerHparams.create(cli_args=True) # reads cli args from sys.argv
logging.getLogger(composer.__name__).setLevel(hparams.log_level)
if args.datadir is not None:
hparams.set_datadir(args.datadir)
logger.info(f'Set dataset dirs in hparams to: {args.datadir}')
trainer = Trainer.create_from_hparams(hparams=hparams)
trainer.fit()

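
Because the hunk above interleaves removed and added lines, here is a rough, hedged reconstruction of how the slimmed-down entry point reads after this change (not a verbatim copy of the file; note that the log level is now applied inside `Trainer.create_from_hparams`, per the trainer.py diff above):

```python
import sys

from composer.trainer.trainer import Trainer
from composer.trainer.trainer_hparams import TrainerHparams


def main() -> None:
    # New behavior: with no CLI arguments, print the help text instead of erroring out.
    if len(sys.argv) == 1:
        sys.argv = [sys.argv[0], "--help"]

    hparams = TrainerHparams.create(cli_args=True)  # reads CLI args from sys.argv
    trainer = Trainer.create_from_hparams(hparams=hparams)
    trainer.fit()


if __name__ == "__main__":
    main()  # assumed entry-point guard; not shown in the diff
```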
