[fix] Seeding fixes, seed random, generate strong seed (#76)
Summary: Pull Request resolved: https://github.com/fairinternal/pythia-internal/pull/76

Reviewed By: vedanuj

Differential Revision: D21105265

Pulled By: apsdehal

fbshipit-source-id: 4f991a81a9614bf4ff2536fe7aff6dac7e2b341a
apsdehal committed May 8, 2020
1 parent d3ee615 commit 82fd8cd
Showing 4 changed files with 36 additions and 23 deletions.
mmf/configs/defaults.yaml (4 changes: 3 additions & 1 deletion)
@@ -4,7 +4,9 @@ training:
trainer: 'base_trainer'
# Seed to be used for training. -1 means a strong random seed will be generated.
# Either pass a fixed seed through your config or via command line arguments
seed: null
# Pass null if you do not want any seeding at all and
# want to leave it to the default behavior
seed: -1
# Name of the experiment, will be used while saving checkpoints
# and generating reports
experiment_name: run
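For reference, a quick sketch of what the different seed settings (a fixed integer, -1, or null) translate to once they reach the new set_seed helper added in mmf/utils/env.py below; the value 1234 is only illustrative, not something this commit prescribes:

from mmf.utils.env import set_seed

set_seed(1234)         # fixed seed from config/CLI: deterministic seeding of numpy, torch, random
strong = set_seed(-1)  # -1: a strong per-run seed is generated and returned
set_seed(None)         # null in YAML -> None: no seeding is performed, returns None
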
mmf/utils/distributed.py (4 changes: 2 additions & 2 deletions)
@@ -175,7 +175,7 @@ def distributed_init(config):
warnings.warn("Distributed is already initialized, cannot initialize twice!")
else:
print(
"| distributed init (rank {}): {}".format(
"Distributed Init (Rank {}): {}".format(
config.distributed.rank, config.distributed.init_method
),
flush=True,
@@ -187,7 +187,7 @@ def distributed_init(config):
rank=config.distributed.rank,
)
print(
"| initialized host {} as rank {}".format(
"Initialized Host {} as Rank {}".format(
socket.gethostname(), config.distributed.rank
),
flush=True,
mmf/utils/env.py (22 changes: 22 additions & 0 deletions)
@@ -0,0 +1,22 @@
import os
import random
from datetime import datetime

import numpy as np
import torch


def set_seed(seed):
    if seed:
        if seed == -1:
            # From detectron2
            seed = (
                os.getpid()
                + int(datetime.now().strftime("%S%f"))
                + int.from_bytes(os.urandom(2), "big")
            )
        np.random.seed(seed)
        torch.manual_seed(seed)
        random.seed(seed)

    return seed
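
A minimal way to exercise the helper (a sketch, assuming mmf and its torch/numpy dependencies are importable; this script is not part of the commit):

import torch
from mmf.utils.env import set_seed

used = set_seed(-1)       # -1: pid + current time + os.urandom(2) combine into a strong seed
print("seed in use:", used)

set_seed(1234)            # fixed seed: torch, numpy and random are all seeded
a = torch.rand(3)
set_seed(1234)            # re-seeding with the same value reproduces the same draws
b = torch.rand(3)
assert torch.equal(a, b)
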
tools/run.py (29 changes: 9 additions & 20 deletions)
@@ -8,26 +8,26 @@
from mmf.utils.build import build_trainer
from mmf.utils.configuration import Configuration
from mmf.utils.distributed import distributed_init, infer_init_method
from mmf.utils.env import set_seed
from mmf.utils.flags import flags
from mmf.utils.general import setup_imports


def main(configuration, init_distributed=False):
setup_imports()
config = configuration.get_config()

if torch.cuda.is_available():
torch.cuda.set_device(config.device_id)
if config.seed:
if config.seed == -1:
config.seed = random.randint(10000, 20000)
np.random.seed(config.seed)
torch.manual_seed(config.seed)
# TODO: Re-enable after project
# random.seed(config.seed)
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True
torch.cuda.init()

if init_distributed:
distributed_init(config)

config.training.seed = set_seed(config.training.seed)
registry.register("seed", config.training.seed)
print("Using seed {}".format(config.training.seed))

trainer = build_trainer(configuration)
trainer.load()
trainer.train()
@@ -80,17 +80,6 @@ def run():
else:
config.device_id = 0
main(configuration)
# Log any errors that occur to log file
# try:
# trainer.load()
# trainer.train()
# except Exception as e:
# writer = getattr(trainer, "writer", None)

# if writer is not None:
# writer.write(e, "error", donot_print=True)
# if is_main_process():
# raise


if __name__ == "__main__":
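
One thing the tools/run.py change relies on but does not show: once the seed is registered, any other MMF component can look it back up instead of re-reading the config. A small hypothetical consumer, assuming the registry's usual get() lookup (not part of this commit):

from mmf.common.registry import registry

seed = registry.get("seed")  # registered during startup in tools/run.py, see the diff above
if seed is not None:
    print("This run was seeded with {}".format(seed))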
