-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexperimental_compact_network.py
87 lines (70 loc) · 2.55 KB
/
experimental_compact_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import ast
import logging
import math
import os
import sys
from itertools import chain
import numpy as np
import torch
from fairseq import checkpoint_utils, options, scoring, tasks, utils
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter, TimeMeter
def main(args):
    """Validate generation arguments, pick an output stream and run ``_main``.

    When ``args.results_path`` is set, that directory is created if missing and
    all output goes to ``<results_path>/generate-<gen_subset>.txt`` (opened
    line-buffered, UTF-8). Otherwise output goes to ``sys.stdout``.

    Args:
        args: parsed command-line namespace.

    Returns:
        Whatever ``_main`` returns.

    Raises:
        AssertionError: if ``--path`` is missing, if ``--sampling`` is used
            with ``--nbest`` different from ``--beam``, or if ``--replace-unk``
            is used with a non-raw dataset implementation.
    """
    assert args.path is not None, "--path required for generation!"
    # Only compare nbest/beam when sampling is actually enabled.
    assert not (
        args.sampling and args.nbest != args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        args.replace_unk is None or args.dataset_impl == "raw"
    ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)"

    if args.results_path is None:
        return _main(args, sys.stdout)

    os.makedirs(args.results_path, exist_ok=True)
    file_name = "generate-{}.txt".format(args.gen_subset)
    destination = os.path.join(args.results_path, file_name)
    with open(destination, "w", buffering=1, encoding="utf-8") as out_stream:
        return _main(args, out_stream)
def _main(args, output_file):
    """Set up the task and load the model ensemble described by ``args``.

    Experimental, truncated generation driver. Its steps are:

    1. Configure logging onto ``output_file``.
    2. Import any user module.
    3. Default ``args.max_tokens`` when no batch limit is given.
    4. Seed NumPy/torch.
    5. Set up the task and load the ``args.gen_subset`` dataset.
    6. Load the checkpoint ensemble from ``args.path`` with
       ``args.model_overrides`` applied.
    7. Write the loaded model arguments to ``output_file``.

    No decoding is performed.

    Args:
        args: parsed command-line namespace.
        output_file: text stream receiving log records and the model
            arguments (``sys.stdout`` or the results file opened by ``main``).

    Returns:
        None.
    """
    # basicConfig does nothing if the root logger already has handlers.
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
    )
    logger = logging.getLogger("fairseq_cli.generate")

    utils.import_user_module(args)

    # Fall back to a token-based batch limit when neither limit was given.
    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 12000
    logger.info(args)

    if args.seed is not None and not args.no_seed_provided:
        np.random.seed(args.seed)
        utils.set_torch_seed(args.seed)

    # NOTE(review): use_cuda, src_dict, tgt_dict and models are not used in
    # this function yet; presumably placeholders for a decoding loop that this
    # experimental script omits.
    use_cuda = torch.cuda.is_available() and not args.cpu

    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # The task may raise NotImplementedError for source_dictionary.
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # literal_eval parses the overrides literal without executing code.
    overrides = ast.literal_eval(args.model_overrides)
    logger.info("model overrides: {}".format(overrides))

    logger.info("loading model(s) from {}".format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        arg_overrides=overrides,
        task=task,
        suffix=getattr(args, "checkpoint_suffix", ""),
        strict=(args.checkpoint_shard_count == 1),
        num_shards=args.checkpoint_shard_count,
    )
    # Write to the same stream as the logs (not unconditionally to stdout), so
    # --results-path captures the result.
    print(_model_args, file=output_file)
def cli_main():
    """Command-line entry point: parse arguments, then hand them to ``main``."""
    arg_parser = options.get_experimental_generation_parser()
    main(options.parse_args_and_arch(arg_parser))


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    cli_main()