This repository was archived by the owner on Dec 16, 2022. It is now read-only.

PretrainedTransformerMismatchedIndexer fails silently when given empty strings as input. #4612

Closed
10 tasks done
dwadden opened this issue Aug 28, 2020 · 6 comments · Fixed by #4615

@dwadden

dwadden commented Aug 28, 2020

Checklist

  • I have verified that the issue exists against the master branch of AllenNLP.
  • I have read the relevant section in the contribution guide on reporting bugs.
  • I have checked the issues list for similar or identical bug reports.
  • I have checked the pull requests list for existing proposed fixes.
  • I have checked the CHANGELOG and the commit log to find out if the bug was already fixed in the master branch.
  • I have included in the "Description" section below a traceback from any exceptions related to this bug.
  • I have included in the "Related issues or possible duplicates" section below all related issues and possible duplicate issues (If there are none, check this box anyway).
  • I have included in the "Environment" section below the name of the operating system and Python version that I was using when I discovered this bug.
  • I have included in the "Environment" section below the output of pip freeze.
  • I have included in the "Steps to reproduce" section below a minimally reproducible example.

Description

When the PretrainedTransformerMismatchedEmbedder is used to embed an empty string '', running backward on the output of the embedder produces nan gradients in the transformer parameters, but it's tough to track down where they came from. It's easier to see with an example; see "Steps to Reproduce".

Suggested fix: The PretrainedTransformerMismatchedEmbedder should throw an error or give a warning when given an empty string as input. I'm happy to implement this if I can get some guidance on the appropriate files to change.
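
Until a fix lands in the library, one possible workaround is to validate the word list before building a TextField. The sketch below is plain Python and not part of AllenNLP; the helper name `check_no_empty_tokens` is hypothetical. It only catches empty or whitespace-only strings, so it would not catch other inputs that the tokenizer maps to zero wordpieces.

```python
from typing import List, Sequence


def check_no_empty_tokens(words: Sequence[str]) -> List[str]:
    """Raise ValueError if any word is empty or whitespace-only.

    Hypothetical helper for illustration; AllenNLP does not provide it.
    """
    bad_positions = [i for i, word in enumerate(words) if not word.strip()]
    if bad_positions:
        raise ValueError(
            f"Empty tokens at positions {bad_positions} in {list(words)!r}"
        )
    return list(words)
```

Calling this on the word list before constructing the Token objects turns the silent NaN into an immediate, local error.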

The output of pip freeze is long and (I think) uninformative, so I've moved it to the bottom of this bug report.

Related issues or possible duplicates

  • None

Steps to reproduce

Runnable example below. The first example does not produce NaN gradients; the second (empty string) and third (unrecognized unicode characters) do.

import torch

from allennlp import data
from allennlp.data import fields
from allennlp import modules


def check_nan_grads(words):
    "Encode a list of words, take a gradient, and check for NaN's."
    print(f"Checking {words}.")
    # Create indexer and embedder.
    tok_indexers = {"bert": data.token_indexers.PretrainedTransformerMismatchedIndexer(
        "bert-base-cased")}
    token_embedder = modules.token_embedders.PretrainedTransformerMismatchedEmbedder(
        "bert-base-cased")
    embedder = modules.text_field_embedders.BasicTextFieldEmbedder({"bert": token_embedder})

    # Convert words to tensor dict.
    vocab = data.Vocabulary()
    text_field = fields.TextField(
        [data.Token(word) for word in words], tok_indexers)
    text_field.index(vocab)
    token_tensor = text_field.as_tensor(text_field.get_padding_lengths())
    tensor_dict = text_field.batch_tensors([token_tensor])

    # Run forward pass. We need a scalar to take the gradient of, so just take the mean of the
    # embeddings.
    output = embedder(tensor_dict)
    loss = output.mean()
    loss.backward()

    # Check whether this produces an NaN in the model parameters.
    for name, param in embedder.named_parameters():
        grad = param.grad
        if grad is not None and torch.any(torch.isnan(grad)):
            print("Found NaN grad.")
            print("Offending tensor_dict:")
            print(tensor_dict)
            print()
            return

    print("No NaN's.")
    print()


####################

# This works fine.
example_safe = ["An", "example"]
check_nan_grads(example_safe)

# This produces NaN grads because of the empty string.
example_bad_empty = ["An", "", "example"]
check_nan_grads(example_bad_empty)

# This produces NaN grads because there's a weird character the indexer doesn't know about.
weird_character = "\uf732\uf730\uf730\uf733"
print(f"Weird character: {weird_character}.")
example_bad_unicode = ["A", weird_character, "example"]
check_nan_grads(example_bad_unicode)

Environment

OS: Linux
Python version: 3.7.6

Pip freeze output:

```
-e git+https://github.com/allenai/allennlp.git@73220d7#egg=allennlp
-e git+https://github.com/allenai/allennlp-models.git@a730fed9424bcbe21186fc7866b195ea9ac7ecc5#egg=allennlp_models
attrs==19.3.0
autopep8 @ file:///tmp/build/80754af9/autopep8_1592412889138/work
backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
beautifulsoup4==4.8.1
blis==0.4.1
boto3==1.14.20
botocore==1.17.20
catalogue==1.0.0
certifi==2020.6.20
chardet==3.0.4
click==7.1.2
conllu==3.0
cymem==2.0.3
decorator==4.4.2
docutils==0.15.2
filelock==3.0.12
flake8==3.8.3
flaky==3.7.0
future==0.18.2
h5py==2.10.0
idna==2.10
importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1593446408836/work
ipdb==0.11
ipython==7.9.0
ipython-genutils==0.2.0
javapackages==4.3.2
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1592619900914/work
jmespath==0.10.0
joblib==0.16.0
jsonnet @ file:///home/conda/feedstock_root/build_artifacts/jsonnet_1590349750875/work
jsonpickle==1.4.1
lxml==4.5.2
mccabe==0.6.1
more-itertools==8.4.0
murmurhash==1.0.2
mypy @ file:///tmp/build/80754af9/mypy_1593442617121/work
mypy-extensions==0.4.3
nltk==3.5
numpy==1.19.0
overrides==3.1.0
packaging==20.4
pandas==0.25.2
parso==0.7.0
pexpect==4.8.0
pickleshare==0.7.5
plac==1.1.3
pluggy==0.13.1
preshed==3.0.2
prompt-toolkit==2.0.10
protobuf==3.12.2
psutil==5.7.0
ptyprocess==0.6.0
py==1.9.0
py-rouge==1.1
pycodestyle==2.6.0
pyflakes==2.2.0
Pygments==2.6.1
pyparsing==2.4.7
pytest==5.4.3
python-dateutil==2.8.1
python-Levenshtein==0.12.0
pytz==2020.1
PyXB==1.2.4
regex==2020.6.8
requests==2.24.0
responses==0.10.15
rope==0.17.0
s3transfer==0.3.3
sacremoses==0.0.43
scikit-learn==0.23.1
scipy==1.5.1
semantic-version==2.8.5
sentencepiece==0.1.91
six @ file:///home/conda/feedstock_root/build_artifacts/six_1590081179328/work
soupsieve==2.0.1
spacy==2.2.4
srsly==1.0.2
tensorboardX==2.1
thinc==7.4.0
threadpoolctl==2.1.0
tokenizers==0.8.1rc1
toml @ file:///tmp/build/80754af9/toml_1592853716807/work
torch==1.6.0
tqdm==4.47.0
traitlets==4.3.3
transformers==3.0.2
typed-ast==1.4.1
typing-extensions @ file:///tmp/build/80754af9/typing_extensions_1592847887441/work
urllib3==1.25.9
wasabi==0.7.0
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1592931742287/work
word2number==1.1
zipp==3.1.0
```
@dwadden dwadden added the bug label Aug 28, 2020
@dwadden dwadden changed the title PretrainedTransformerMismatchedIndexer fails silently when given empty strings as input. PretrainedTransformerMismatchedIndexer fails silently when given empty strings as input. Aug 28, 2020
@matt-gardner
Contributor

Thanks for the thorough bug report! Just FYI, I fixed the pip freeze output to be what was intended; you can see that it's now in a drop down, so it's not so annoying.

Can you also include what print(tensor_dict) looks like? I'm not sure what will happen with that empty string. Is it just all padding, and that's what's causing nans? I'm guessing so.

I agree that throwing an error somewhere is the right fix. I'm not certain where that error should be, but right now I'm thinking probably in the tokenizer. Seeing the output of print(tensor_dict) might make me change my mind, though.

@dwadden
Author

dwadden commented Aug 28, 2020

Ah, thanks for fixing pip freeze.

It turns out that, in addition to empty strings, weird unicode characters can also mess things up; presumably they're not recognized by the token indexer. I added an example in the code snippet above to show what happens.

I've included a tensor_dict printout below. I think the -1 entries in offsets might be what's messing things up. Let me know what you think is the best spot to throw an error.

Checking ['An', '', 'example'].
Found NaN grad.
Offending tensor_dict:
{'bert': {'mask': tensor([[True, True, True]]),
          'offsets': tensor([[[ 1,  1],
                              [-1, -1],
                              [ 2,  2]]]),
          'token_ids': tensor([[ 101, 1760, 1859,  102]]),
          'type_ids': tensor([[0, 0, 0, 0]]),
          'wordpiece_mask': tensor([[True, True, True, True]])}}
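
The printout suggests that a word yielding no wordpieces gets the placeholder offset (-1, -1) while the surrounding words keep consecutive wordpiece positions. The following is a rough, pure-Python sketch of that bookkeeping, not AllenNLP's actual code: `word_offsets` and `toy_wordpieces` are hypothetical names, and the toy tokenizer simply emits one piece per non-empty word.

```python
from typing import Callable, List, Sequence, Tuple


def word_offsets(
    words: Sequence[str],
    wordpiece_fn: Callable[[str], List[str]],
    num_start_tokens: int = 1,  # assumes one leading special token such as [CLS]
) -> List[Tuple[int, int]]:
    """Map each word to the inclusive (start, end) span of its wordpieces."""
    offsets = []
    position = num_start_tokens
    for word in words:
        pieces = wordpiece_fn(word)
        if pieces:
            offsets.append((position, position + len(pieces) - 1))
            position += len(pieces)
        else:
            # No wordpieces at all: mark the word with a placeholder span.
            offsets.append((-1, -1))
    return offsets


def toy_wordpieces(word: str) -> List[str]:
    """Hypothetical stand-in tokenizer: one piece per non-empty word."""
    return [word] if word else []


print(word_offsets(["An", "", "example"], toy_wordpieces))
# prints [(1, 1), (-1, -1), (2, 2)]
```

Under these assumptions the result matches the offsets tensor above: the empty string owns no wordpieces, so any per-word average over its span has nothing to average.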

@matt-gardner
Contributor

Ok, figured it out. We need a torch.clamp_min(..., 1) on this line:

orig_embeddings = span_embeddings_sum / span_embeddings_len

That resolves the issue. Can you open a PR for this that includes a simple test based on the minimal example you gave above? That would be awesome.

@dwadden
Author

dwadden commented Aug 30, 2020

I changed that line to orig_embeddings = torch.clamp_min(span_embeddings_sum / span_embeddings_len, 1), but I'm still getting nan gradients (I wrote a unit test for this). Is there something else that needs to be changed?

I've got a PR for this, but I didn't want to submit since the test suite wouldn't pass if the PR were accepted as is. Let me know if I should submit anyhow.

@matt-gardner
Contributor

Sorry I wasn't specific enough; it's the division by zero that's the problem, so you need to clamp the lengths to a min of 1. So, orig_embeddings = span_embeddings_sum / torch.clamp_min(span_embeddings_len, 1).
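
To make the division-by-zero point concrete, here is a small standalone sketch of span mean-pooling in PyTorch. It is not the embedder's actual code: the function `mean_pool_spans`, the tensor shapes, and the toy mask are assumptions for illustration, and only the variable names `span_embeddings_sum` and `span_embeddings_len` come from the line quoted above. The middle word has an all-False mask, mimicking the empty string.

```python
import torch


def mean_pool_spans(span_embeddings, span_mask, clamp=True):
    """Average the wordpiece embeddings inside each word's span.

    span_embeddings: (batch, num_words, max_span_width, dim)
    span_mask:       (batch, num_words, max_span_width), True for real wordpieces
    """
    mask = span_mask.unsqueeze(-1).to(span_embeddings.dtype)
    span_embeddings_sum = (span_embeddings * mask).sum(2)
    span_embeddings_len = mask.sum(2)
    if clamp:
        # Clamp the denominator (not the quotient) so empty spans divide by 1.
        span_embeddings_len = torch.clamp_min(span_embeddings_len, 1)
    return span_embeddings_sum / span_embeddings_len


def has_nan_grad(clamp):
    """Run forward and backward once and report whether any gradient is NaN."""
    embeddings = torch.randn(1, 3, 2, 4, requires_grad=True)
    # The second word has no wordpieces, like the empty string in the report.
    mask = torch.tensor([[[True, False], [False, False], [True, True]]])
    mean_pool_spans(embeddings, mask, clamp=clamp).mean().backward()
    return bool(torch.isnan(embeddings.grad).any())


print("NaN grads without clamp:", has_nan_grad(clamp=False))
print("NaN grads with clamp:", has_nan_grad(clamp=True))
```

Without the clamp, the empty span computes 0 / 0 in the forward pass, and the backward pass multiplies an infinite gradient by the zero mask, which gives NaN. Clamping the quotient instead of the length cannot help, because the NaN already exists before the clamp is applied.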

@dwadden
Author

dwadden commented Aug 30, 2020

PR submitted: #4615
