
[XLNet] Parameters to reproduce SQuAD scores #947

Closed
astariul opened this issue Aug 2, 2019 · 18 comments

@astariul (Contributor) commented Aug 2, 2019

I'm trying to reproduce the results of XLNet-base on SQuAD 2.0.

From the README of XLNet:

| Model | RACE accuracy | SQuAD1.1 EM | SQuAD2.0 EM |
| --- | --- | --- | --- |
| BERT-Large | 72.0 | 84.1 | 78.98 |
| XLNet-Base | | | 80.18 |
| XLNet-Large | 81.75 | 88.95 | 86.12 |

I ran the example with the following hyperparameters, on a single P100 GPU:

python ./examples/run_squad.py \
--model_type xlnet \
--model_name_or_path xlnet-base-cased \
--do_train \
--do_eval \
--train_file squad/train-v1.1.json \
--predict_file squad/dev-v1.1.json \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir ./finetuned_squad_xlnet \
--per_gpu_eval_batch_size 8 \
--per_gpu_train_batch_size 8 \
--save_steps 1000

And I got these results:

{
"exact": 72.88552507095554,
"f1": 80.81417081310839,
"total": 10570,
"HasAns_exact": 72.88552507095554,
"HasAns_f1": 80.81417081310839,
"HasAns_total": 10570
}

It's 8 points lower than the official results.

What are the parameters needed to reach the same score as the official implementation?


I opened a separate issue from #822 because my results are not that far off.

@thomwolf (Member) commented Aug 5, 2019

Maybe we can use the same issue so the people following #822 can learn from your experiments as well?

@hlums (Contributor) commented Sep 27, 2019

I'm using xlnet-large-cased.
At first I got
{
"exact": 75.91296121097446,
"f1": 83.19559419987176,
"total": 10570,
"HasAns_exact": 75.91296121097446,
"HasAns_f1": 83.19559419987176,
"HasAns_total": 10570
}

Then I took a look at the XLNet repo and found that the current preprocessing in transformers is a little off. In the XLNet repo, they have P SEP Q SEP CLS, but the preprocessing code in this repo has CLS Q SEP P SEP. I tried to follow the XLNet repo preprocessing code and the hyperparameters in the paper, and now I have
{
"exact": 84.37086092715232,
"f1": 92.01817406538726,
"total": 10570,
"HasAns_exact": 84.37086092715232,
"HasAns_f1": 92.01817406538726,
"HasAns_total": 10570
}
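
For reference, here is a tiny sketch (toy word-level tokens rather than real SentencePiece output, so purely illustrative) contrasting the two orderings:

```python
# Toy question and paragraph tokens, for illustration only.
question = ["what", "is", "x"]
paragraph = ["x", "is", "y"]

# BERT-style layout currently produced by examples/run_squad.py:
# [CLS] Q [SEP] P [SEP]
bert_style = ["[CLS]"] + question + ["[SEP]"] + paragraph + ["[SEP]"]

# Layout used by the original XLNet repo:
# P [SEP] Q [SEP] [CLS]
xlnet_style = paragraph + ["[SEP]"] + question + ["[SEP]", "[CLS]"]

print(bert_style)
print(xlnet_style)
```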

Here is my preprocessing code with the changes. Sorry it's a bit messy. I will create a PR next week.

```python
# xlnet
cls_token = "[CLS]"
sep_token = "[SEP]"
pad_token = 0
sequence_a_segment_id = 0
sequence_b_segment_id = 1
cls_token_segment_id = 2
# Should this be 4, or does it not matter?
pad_token_segment_id = 3
cls_token_at_end = True
mask_padding_with_zero = True
# xlnet

qa_features = []

# unique_id identifies unique feature/label pairs. It's different
# from qa_id in that each qa_example can be broken down into
# multiple feature samples if the paragraph length is longer than
# the maximum sequence length allowed.
query_tokens = tokenizer.tokenize(example.question_text)

if len(query_tokens) > max_question_length:
	query_tokens = query_tokens[0:max_question_length]
# map word-piece tokens to original tokens
tok_to_orig_index = []
# map original tokens to corresponding word-piece tokens
orig_to_tok_index = []
all_doc_tokens = []
for (i, token) in enumerate(example.doc_tokens):
	orig_to_tok_index.append(len(all_doc_tokens))
	sub_tokens = tokenizer.tokenize(token)
	for sub_token in sub_tokens:
		tok_to_orig_index.append(i)
		all_doc_tokens.append(sub_token)

tok_start_position = None
tok_end_position = None
if is_training and example.is_impossible:
	tok_start_position = -1
	tok_end_position = -1
if is_training and not example.is_impossible:
	tok_start_position = orig_to_tok_index[example.start_position]
	if example.end_position < len(example.doc_tokens) - 1:
		# +1: move to the token after the ending token in the
		# original tokens; -1: move one step back.
		# These two operations ensure the word piece is covered
		# when it's part of the original ending token.
		tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
	else:
		tok_end_position = len(all_doc_tokens) - 1
	(tok_start_position, tok_end_position) = _improve_answer_span(
		all_doc_tokens,
		tok_start_position,
		tok_end_position,
		tokenizer,
		example.orig_answer_text,
	)

# The -3 accounts for [CLS], [SEP] and [SEP]
max_tokens_for_doc = max_seq_len - len(query_tokens) - 3

# We can have documents that are longer than the maximum sequence length.
# To deal with this we do a sliding window approach, where we take chunks
# of up to our max length with a stride of `doc_stride`.
_DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
doc_spans = []
start_offset = 0
while start_offset < len(all_doc_tokens):
	length = len(all_doc_tokens) - start_offset
	if length > max_tokens_for_doc:
		length = max_tokens_for_doc
	doc_spans.append(_DocSpan(start=start_offset, length=length))
	if start_offset + length == len(all_doc_tokens):
		break
	start_offset += min(length, doc_stride)

for (doc_span_index, doc_span) in enumerate(doc_spans):
	if is_training:
		unique_id += 1
	else:
		unique_id += 2

	tokens = []
	token_to_orig_map = {}
	token_is_max_context = {}
	segment_ids = []

	# p_mask: mask with 1 for tokens that cannot be in the answer
	# (0 for tokens which can be in an answer).
	# The original TF implementation also keeps the classification token (set to 0),
	# because the cls token represents the prediction for unanswerable questions.
	p_mask = []

	# CLS token at the beginning
	if not cls_token_at_end:
		tokens.append(cls_token)
		segment_ids.append(cls_token_segment_id)
		p_mask.append(0)
		cls_index = 0


	# Paragraph
	for i in range(doc_span.length):
		split_token_index = doc_span.start + i
		token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]

		## TODO: maybe this can be improved to compute
		# is_max_context for each token only once.
		is_max_context = _check_is_max_context(doc_spans, doc_span_index, split_token_index)
		token_is_max_context[len(tokens)] = is_max_context
		tokens.append(all_doc_tokens[split_token_index])
		# xlnet
		# segment_ids.append(sequence_b_segment_id)
		segment_ids.append(sequence_a_segment_id)
		# xlnet ends
		p_mask.append(0)
	paragraph_len = doc_span.length

	# xlnet: SEP after the paragraph, then the question, giving the layout P [SEP] Q [SEP] [CLS]
	tokens.append(sep_token)
	segment_ids.append(sequence_a_segment_id)
	p_mask.append(1)

	tokens += query_tokens
	segment_ids += [sequence_b_segment_id] * len(query_tokens)
	p_mask += [1] * len(query_tokens)
	# xlnet ends

	# SEP token
	tokens.append(sep_token)
	segment_ids.append(sequence_b_segment_id)
	p_mask.append(1)

	# CLS token at the end
	if cls_token_at_end:
		tokens.append(cls_token)
		segment_ids.append(cls_token_segment_id)
		p_mask.append(0)
		cls_index = len(tokens) - 1  # Index of classification token

	input_ids = tokenizer.convert_tokens_to_ids(tokens)

	# The mask has 1 for real tokens and 0 for padding tokens. Only real
	# tokens are attended to.
	input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

	# Zero-pad up to the sequence length.
	if len(input_ids) < max_seq_len:
		pad_token_length = max_seq_len - len(input_ids)
		pad_mask = 0 if mask_padding_with_zero else 1
		input_ids += [pad_token] * pad_token_length
		input_mask += [pad_mask] * pad_token_length
		segment_ids += [pad_token_segment_id] * pad_token_length
		p_mask += [1] * pad_token_length

	assert len(input_ids) == max_seq_len
	assert len(input_mask) == max_seq_len
	assert len(segment_ids) == max_seq_len
	assert len(p_mask) == max_seq_len

	span_is_impossible = example.is_impossible
	start_position = None
	end_position = None
	if is_training and not span_is_impossible:
		# For training, if our document chunk does not contain an annotation
		# we throw it out, since there is nothing to predict.
		doc_start = doc_span.start
		doc_end = doc_span.start + doc_span.length - 1
		out_of_span = False
		if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
			out_of_span = True
		if out_of_span:
			start_position = 0
			end_position = 0
			span_is_impossible = True
		else:
			# xlnet: the paragraph now comes first in the sequence, so no offset
			# is needed. (In the original BERT-style layout the offset was
			# len(query_tokens) + 2 to skip [CLS], the question and the first [SEP].)
			# doc_offset = len(query_tokens) + 2
			doc_offset = 0
			# xlnet ends
			start_position = tok_start_position - doc_start + doc_offset
			end_position = tok_end_position - doc_start + doc_offset

	if is_training and span_is_impossible:
		start_position = cls_index
		end_position = cls_index
```

@pminervini (Contributor) commented Oct 1, 2019

@hlums @colanim that's amazing, thank you! Did you also experiment with SQuAD 2.0? I'm having issues training anything even remotely decent, and deciding whether to answer or not (NoAnswer) seems to be the problem.

@hlums (Contributor) commented Oct 1, 2019

> @hlums @colanim that's amazing, thank you! Did you also experiment with SQuAD 2.0? I'm having issues training anything even remotely decent, and deciding whether to answer or not (NoAnswer) seems to be the problem.

I haven't gotten a chance to try SQuAD 2.0. My guess is that since the CLS token is needed in SQuAD 2.0 to predict unanswerable questions, when the CLS token is misplaced, the impact on model performance is bigger.
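
As a toy illustration of that point (the tokens below are made up, following the XLNet-style layout from the preprocessing above): unanswerable examples are trained to point both the start and end heads at the CLS index, so a misplaced CLS corrupts exactly the target that SQuAD 2.0 relies on.

```python
# Illustrative only: XLNet-style layout with CLS at the end.
tokens = ["x", "is", "y", "[SEP]", "what", "is", "x", "[SEP]", "[CLS]"]
cls_index = len(tokens) - 1

# Training target for an unanswerable (SQuAD 2.0) example:
# both the start and end positions point at the CLS token.
start_position, end_position = cls_index, cls_index
print(start_position, end_position)
```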

@thomwolf (Member) commented Oct 9, 2019

This is great @hlums! Looking forward to a PR updating the example if you have time.

@hlums (Contributor) commented Oct 11, 2019

Updating after I read the comments in #1405 carefully.
I've created a local branch with my changes. I will validate it over the weekend.
I'm trying to push my branch to remote and got an access denied error.
This is how I cloned the repo:
git clone https://hlums:<my personal access token>@github.com/huggingface/transformers/
Can anyone help?

@pminervini (Contributor)

@hlums hey, you can just fork this repo, make your changes in your fork, and then open a pull request - that should work.
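
For example, a typical fork workflow looks roughly like this (the branch name is just a placeholder, and this assumes the fork already exists under your GitHub account):

git clone git@github.com:hlums/transformers.git
cd transformers
git checkout -b xlnet-squad-preprocessing
# make the changes, commit, then push the branch to the fork
git commit -am "Fix XLNet SQuAD preprocessing"
git push origin xlnet-squad-preprocessing

Then open the pull request against huggingface/transformers from the GitHub web UI.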

@slayton58 (Contributor) commented Oct 14, 2019 via email

@hlums (Contributor) commented Oct 14, 2019

Thanks for the clarification @slayton58! I figured it out after reading the comments in your PR more carefully. :)

@hlums (Contributor) commented Oct 14, 2019

Thank you guys! I solved the permission denied issue by cloning over SSH instead of HTTPS. Not sure why I never had this issue with my company's repos.
Anyway, I forked the repo (https://github.com/hlums/transformers) and pushed my changes to it.
However, I'm still having issues running the run_squad.py script. I'm getting "/data/anaconda/envs/py35/bin/python: Relative module names not supported".

Here is what I did:

conda install pytorch
cd transformers
pip install --editable .
bash run_squad.sh

The content of my bash script is the following:

python -m ./examples/run_squad.py \
    --model_type xlnet \
    --model_name_or_path xlnet-large-cased \
    --do_train \
    --do_eval \
    --do_lower_case \
    --train_file /data/home/hlu/notebooks/NLP/examples/question_answering/train-v1.1.json \
    --predict_file /data/home/hlu/notebooks/NLP/examples/question_answering/dev-v1.1.json \
    --learning_rate 3e-5 \
    --num_train_epochs 2 \
    --max_seq_length 384 \
    --doc_stride 128 \
    --output_dir ./wwm_cased_finetuned_squad/ \
    --per_gpu_eval_batch_size=4  \
    --per_gpu_train_batch_size=4   \
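
(For what it's worth, python -m expects a dotted module name rather than a file path, so the error above is most likely triggered by the -m flag itself. A plausible fix, not spelled out later in this thread, is to call the script directly with the same arguments:)

python ./examples/run_squad.py \
    --model_type xlnet \
    (remaining arguments unchanged from the script above)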

@ahotrod (Contributor) commented Oct 14, 2019

@hlums
Is your configuration single- or multi-GPU?
Are you using PyTorch==1.3.0 and Transformers==2.1.1?

The reason I ask is that with 2 x 1080Ti NVIDIA GPUs trying to run run_squad.py on XLNet & BERT models, I experience data-parallel run failures and distributed performance-reporting (key error) failures. Perhaps you have the solution to either/both?

@hlums (Contributor) commented Oct 15, 2019

@ahotrod I'm using PyTorch 1.2.0. I have 4 NVIDIA V100s.
How are you running the script? Are you calling python -m torch.distributed.launch...? Can you try removing torch.distributed.launch? I think it's intended to be used for multi-node training in the way run_squad.py is written, although it can be used for multi-GPU training if we make some changes to run_squad.py.

@slayton58 (Contributor)

@ahotrod I've been seeing key errors only when running eval in distributed -- training is fine (and I've run quite a few full 8xV100 distributed finetunings in the last few weeks), but I have to drop back to DataParallel for eval to work.
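
A minimal sketch of that fallback (standard PyTorch DataParallel usage, not a patch to run_squad.py; the Linear layer stands in for the fine-tuned model):

```python
import torch

# Stand-in for the fine-tuned XLNet model; any nn.Module works the same way.
model = torch.nn.Linear(10, 2)
if torch.cuda.is_available():
    model = model.cuda()

# Fall back to single-process multi-GPU (DataParallel) for evaluation.
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
model.eval()
```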

@ahotrod (Contributor) commented Oct 15, 2019

@hlums @slayton58 Thank you both for the informative, helpful replies.

** Updated, I hope I adequately explain my work-around **

I prefer distributed processing for the training speed-up; plus, my latest data-parallel runs have been loading one of <parameters & buffers> on cuda:1 and shutting down. As recommended, I dropped the do_eval argument and ran my distributed shell script below, which worked fine. I then ran a do_eval script on a single GPU to generate the predictions_.json file, which I don't get from a distributed script when do_eval is included (key error).

Here's my distributed fine-tuning script:

SQUAD_DIR=/media/dn/dssd/nlp/transformers/examples/squad1.1
export OMP_NUM_THREADS=6

python -m torch.distributed.launch --nproc_per_node=2 ./run_squad.py \
  --model_type xlnet \
  --model_name_or_path xlnet-large-cased \
  --do_train \
  --do_lower_case \
  --train_file ${SQUAD_DIR}/train-v1.1.json \
  --predict_file ${SQUAD_DIR}/dev-v1.1.json \
  --num_train_epochs 3 \
  --learning_rate 3e-5 \
  --max_seq_length 384 \
  --doc_stride 128 \
  --save_steps=10000 \
  --per_gpu_train_batch_size 1 \
  --gradient_accumulation_steps 4 \
  --output_dir ./runs/xlnet_large_squad1_dist_X \

which maxes out my 2 x 1080Ti GPUs (0: hybrid, 1: open-frame cooling):

***** Running training *****
Num examples = 89993
Num Epochs = 3
Instantaneous batch size per GPU = 1
Total train batch size (w. parallel, distributed & accumulation) = 8
Gradient Accumulation steps = 4
Total optimization steps = 33747

NVIDIA-SMI 430.50       Driver Version: 430.50       CUDA Version: 10.1

0 GeForce GTX 1080Ti
0%    51C    P2   256W / 250W |  10166MiB / 11178MiB |    100%

1 GeForce GTX 1080Ti
35%   65C    P2   243W / 250W |  10166MiB / 11178MiB |     99%      

After 3 epochs & ~21 hours, here are the results, similar to @colanim's:

***** Running evaluation  *****
Num examples = 11057
Batch size = 32
{
  "exact": 75.01419110690634,
  "f1": 82.13017516396678,
  "total": 10570,
  "HasAns_exact": 75.01419110690634,
  "HasAns_f1": 82.13017516396678,
  "HasAns_total": 10570
}

generated from my single-GPU do_eval script pointing to the path of the distributed fine-tuned model:

CUDA_VISIBLE_DEVICES=0  python run_squad.py \
  --model_type xlnet \
  --model_name_or_path ${MODEL_PATH} \
  --do_eval \
  --do_lower_case \
  --train_file ${SQUAD_DIR}/train-v1.1.json \
  --predict_file ${SQUAD_DIR}/dev-v1.1.json \
  --per_gpu_eval_batch_size 32 \
  --output_dir ${MODEL_PATH}

This model performs well in my Q&A application, but I'm looking forward to @hlums' pre-processing code, the imminent RoBERTa-large-SQuAD2.0, and perhaps one day ALBERT, for the low-resource user that I am.

@hlums (Contributor) commented Oct 16, 2019

OK, I figured out the relative module import issue. The code is running now and I should have the PR tomorrow if nothing else goes wrong.

@hlums (Contributor) commented Oct 17, 2019

The PR is here: #1549. My current result is
{
"exact": 85.45884578997162,
"f1": 92.5974600601065,
"total": 10570,
"HasAns_exact": 85.45884578997162,
"HasAns_f1": 92.59746006010651,
"HasAns_total": 10570
}

Still a few points lower than what's reported in the XLNet paper, but we made some progress. :)

@Swathygsb

How to convert cls_logits (optional, returned if start_positions or end_positions is not provided) to probability values between 0 and 1?

stale bot commented Jan 19, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the wontfix label Jan 19, 2020
stale bot closed this as completed Jan 26, 2020