Shard eval dataset and aggregate eval metrics #10

Merged: 3 commits, Apr 2, 2020
examples/run_glue_tpu.py (35 changes: 23 additions & 12 deletions)

@@ -18,13 +18,14 @@
from __future__ import absolute_import, division, print_function

import argparse
+from collections import defaultdict
import glob
import logging
import math
import multiprocessing
import os
-import time
import random
+import time

import numpy as np
import torch
@@ -205,18 +206,18 @@ def evaluate(args, model, tokenizer, prefix="", disable_logging=False):

    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
-    output_dir = '{}/eval-xla{}'.format(args.output_dir, xm.get_ordinal())
-    eval_outputs_dirs = (output_dir, output_dir + '-MM') if args.task_name == "mnli" else (output_dir,)
+    eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,)

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
+        eval_sampler = get_sampler(eval_dataset)

        if not os.path.exists(eval_output_dir):
            os.makedirs(eval_output_dir)

-        # Note that we don't shard for TPU Multiprocess as we don't reduce loss among client processes.
-        dataloader = DataLoader(eval_dataset, batch_size=args.eval_batch_size, shuffle=False)
+        dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, shuffle=False)
        eval_dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        # Eval!
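The new dataloader relies on a get_sampler helper that this hunk doesn't show. A plausible sketch, assuming the usual torch_xla sharding idiom (the body below is an assumption, not the PR's actual code):

import torch_xla.core.xla_model as xm
from torch.utils.data import SequentialSampler
from torch.utils.data.distributed import DistributedSampler

def get_sampler(dataset):
    # Single process: no sharding needed.
    if xm.xrt_world_size() <= 1:
        return SequentialSampler(dataset)
    # Multiprocess: give each TPU ordinal a distinct, non-overlapping shard.
    return DistributedSampler(dataset,
                              num_replicas=xm.xrt_world_size(),
                              rank=xm.get_ordinal(),
                              shuffle=False)

With each ordinal evaluating only its shard, the per-worker outputs must be merged afterwards, which is what the mesh_reduce calls in the next hunks do.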
@@ -238,9 +239,9 @@ def evaluate(args, model, tokenizer, prefix="", disable_logging=False):
                # XLM, DistilBERT and RoBERTa don't use segment_ids
                inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None
                outputs = model(**inputs)
-                tmp_eval_loss, logits = outputs[:2]
+                batch_eval_loss, logits = outputs[:2]

-                eval_loss += tmp_eval_loss.mean().item()
+                eval_loss += batch_eval_loss
            nb_eval_steps += 1
            if preds is None:
                preds = logits.detach().cpu().numpy()
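Note the dropped .item(): the running loss now stays an XLA tensor inside the loop (the single .item() call moves to results['eval_loss'] = eval_loss.item() below), so each step no longer forces a device-to-host sync. A standalone, hypothetical illustration of the pattern, not taken from the PR:

import torch

def mean_loss(batch_losses):
    total = 0.0
    for batch_loss in batch_losses:   # each item: a 0-dim loss tensor
        total += batch_loss           # float + tensor -> tensor; stays on device
    return (total / len(batch_losses)).item()  # one host transfer at the end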
@@ -249,21 +250,31 @@ def evaluate(args, model, tokenizer, prefix="", disable_logging=False):
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)

+        # Get all predictions and labels from all workers
+        preds = xm.mesh_reduce('eval_preds', preds, np.concatenate)
+        out_label_ids = xm.mesh_reduce(
+            'eval_out_label_ids', out_label_ids, np.concatenate)

        eval_loss = eval_loss / nb_eval_steps
        if args.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(eval_task, preds, out_label_ids)
        results.update(result)
+        results['eval_loss'] = eval_loss.item()

        output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
-        with open(output_eval_file, "w") as writer:
-            logger.info("***** Eval results {} *****".format(prefix))
-            for key in sorted(result.keys()):
-                logger.info("  %s = %s", key, str(result[key]))
-                writer.write("%s = %s\n" % (key, str(result[key])))
-                tb_writer.add_scalar(key, result[key])
+        if xm.is_master_ordinal():

Review thread on the new `if xm.is_master_ordinal():` line:

Reviewer: is everything being logged here already on cpu?

Author: Yes.

Reviewer: maybe let's add a comment? It's a subtle point that can be missed by code readers.
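A comment along the lines the reviewer asks for might read as follows (a sketch restating what the thread establishes, not the committed text):

        # Safe to log from the master ordinal only: by this point `results`
        # holds host-side values (mesh_reduce returned NumPy arrays and
        # eval_loss.item() is a Python float).
        if xm.is_master_ordinal():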

+            with open(output_eval_file, "w") as writer:
+                logger.info("***** Eval results {} *****".format(prefix))
+                for key in sorted(results.keys()):
+                    logger.info("  %s = %s", key, str(results[key]))
+                    writer.write("%s = %s\n" % (key, str(results[key])))
+                    tb_writer.add_scalar(key, results[key])

    if args.metrics_debug:
        xm.master_print(met.metrics_report())

    tb_writer.close()
    return results
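For readers new to xm.mesh_reduce: each process contributes its local value under a shared tag, and the reducing callable receives the list of per-process values. A minimal standalone sketch of the aggregation used above (hypothetical helper name; assumes each worker supplies NumPy arrays of matching rank):

import numpy as np
import torch_xla.core.xla_model as xm

def gather_eval_outputs(local_preds, local_labels):
    # Every TPU ordinal contributes its shard; np.concatenate receives the
    # list of per-process arrays and stitches them into the full eval set.
    preds = xm.mesh_reduce('eval_preds', local_preds, np.concatenate)
    labels = xm.mesh_reduce('eval_out_label_ids', local_labels, np.concatenate)
    return preds, labels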