
Changing a single example for BLOOM 176-B affects forward pass for other examples in a batch #18809

Closed

mayank31398 opened this issue Aug 29, 2022 · 9 comments

@mayank31398
Contributor

mayank31398 commented Aug 29, 2022

System Info

  • transformers version: 4.21.2
  • Platform: Linux-4.18.0-305.25.1.el8_4.x86_64-x86_64-with-glibc2.17
  • Python version: 3.8.13
  • Huggingface_hub version: 0.9.1
  • PyTorch version (GPU?): 1.11.0 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: yes

Who can help?

@thomasw21, @younesbelkada This issue is for unexpected BLOOM outputs.

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I wrote this script to get the conditional NLL of the labels given the context.
I tried different batches in which only the first example changes and the rest of the examples stay fixed. However, after a certain point, changing the first example affects the NLL of the other examples.

This is not supposed to happen.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "bigscience/bloom"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    max_memory={0: '0GIB', 1: '51GIB', 2: '51GIB', 3: '51GIB',
                4: '51GIB', 5: '51GIB', 6: '51GIB', 7: '51GIB'},
    torch_dtype=torch.bfloat16,
)

model.eval()

def compute_gen_loss(lm_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    batch_size = labels.shape[0]
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    loss = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1)
    )
    loss = loss.reshape(batch_size, -1)
    loss = loss.sum(dim=-1) / (shift_labels != -100).sum(dim=-1)
    return loss


def pad_ids(arrays, padding, max_length=-1):
    if (max_length < 0):
        max_length = max(list(map(len, arrays)))

    arrays = [[padding] * (max_length - len(array)) +
              array for array in arrays]

    return arrays


def forward(text: list, labels: list, conditional: bool = True):
    input_tokens = tokenizer(text).input_ids
    label_tokens = tokenizer(labels).input_ids

    input_ids = [x + y for (x, y) in zip(input_tokens, label_tokens)]
    attention_mask = [(len(x) + len(y)) * [1]
                      for (x, y) in zip(input_tokens, label_tokens)]
    if (conditional):
        labels = [[-100] * len(x) + y for (x, y)
                  in zip(input_tokens, label_tokens)]
    else:
        labels = input_ids

    pad = 3
    input_ids = pad_ids(input_ids, pad)
    attention_mask = pad_ids(attention_mask, 0)
    # labels need to be on output device
    labels = pad_ids(labels, -100)

    input_ids = torch.tensor(input_ids)
    attention_mask = torch.tensor(attention_mask)
    labels = torch.tensor(labels)
    lm_logits = model(
        input_ids=input_ids,
        attention_mask=attention_mask
    ).logits

    print(compute_gen_loss(lm_logits, labels.to(lm_logits.device)).cpu().tolist())

text = [
    "DeepSpeed",
    "DeepSpeed is a",
    "DeepSpeed is a machine",
    "DeepSpeed is a machine learning framework",
]
labels = [
    " is awesome.",
    " good person.",
    " that can wipe out the planet.",
    " for generating memes.",
]
forward(text, labels)

labels[0] = " is awesome. really awesome"
forward(text, labels)

labels[0] = " is awesome. really awesome. Try it."
forward(text, labels)

labels[0] = " is awesome. really awesome. Try it. You'll be surprised"
forward(text, labels)

labels[0] = " is awesome. really awesome. Try it. You'll be surprised. BLOOM was trained using DeepSpeed."
forward(text, labels)

labels[0] = " is awesome. really awesome. Try it. You'll be surprised. BLOOM was trained using DeepSpeed. Oh no the values are bugging out now."
forward(text, labels)

Output:

[4.8125, 5.1875, 3.296875, 5.09375]
[5.625, 5.1875, 3.296875, 5.09375]
[4.375, 5.1875, 3.296875, 5.09375]
[4.0625, 5.1875, 3.28125, 5.09375]
[3.953125, 5.1875, 3.28125, 5.0625]
[4.25, 5.1875, 3.296875, 5.09375]

The value in column 2 drops from 3.296875 to 3.28125 when only the example in column 0 is changed, and even column 3 changes in one of the runs.
Only column 0 is supposed to change here.

Expected behavior

[4.8125, 5.1875, 3.296875, 5.09375]
[5.625, 5.1875, 3.296875, 5.09375]
[4.375, 5.1875, 3.296875, 5.09375]
[4.0625, 5.1875, 3.296875, 5.09375]
[3.953125, 5.1875, 3.296875, 5.09375]
[4.25, 5.1875, 3.296875, 5.09375]
@mayank31398 mayank31398 changed the title Changing a single example affects forward pass for other examples in a batch Changing a single example for BLOOM 176-B affects forward pass for other examples in a batch Aug 29, 2022
@thomasw21
Contributor

thomasw21 commented Aug 29, 2022

Hey! It's a bit hard to run a testing env with bloom, can you share a reproducible script with a smaller model?

This looks like some instability from torch.bfloat16, and I'm willing to bet that those values come from there (both 3.28 occurrences are exactly the same, so it seems like a rounding error to me; we can perhaps check that those values are consecutive values in bfloat16, i.e. there's no representable value between 3.28 and 3.29). What I think might be happening is that you're adding pads as you increase the length of the labels, and those pad values change the behaviour of the previous values. I don't think we have much control over this, as it usually comes down to the torch operators.
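
One quick way to check the spacing (a minimal sketch): bfloat16 keeps 8 significand bits, so in the range [2, 4) consecutive representable values are 0.015625 apart, which makes 3.28125 and 3.296875 adjacent.

import torch

# bfloat16 has 8 significand bits, so in [2, 4) the gap between consecutive
# representable values is 2 * 2**-7 = 0.015625
a = torch.tensor(3.28125, dtype=torch.bfloat16)
b = torch.tensor(3.296875, dtype=torch.bfloat16)
print((b - a).item())  # 0.015625, i.e. the two losses differ by exactly one ULP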

Also, if you can run on main that'd be great; #18344 hasn't been incorporated in a release yet, and I think it fixed a bunch of instabilities.

@mayank31398
Contributor Author

Thanks @thomasw21 for taking a look at this. I will try to reproduce this with a smaller model (say GPT-2) and get back on this. I will also try the main branch.

@mayank31398
Contributor Author

Also, since there are no batch-norm ops in BLOOM, I don't really understand why this should happen. And since the pads have been given an attention mask of 0, shouldn't the output be the same?
Maybe I am understanding this incorrectly.
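
One way to test the padding question directly (a rough sketch, using gpt2 in fp32 rather than BLOOM) is to compare the logits of the same tokens with and without extra left padding:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tok = AutoTokenizer.from_pretrained("gpt2")
m = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float32).eval()

ids = tok("DeepSpeed is a machine", return_tensors="pt").input_ids
pad = torch.full((1, 4), 3)  # arbitrary token ids used as left padding

with torch.no_grad():
    plain = m(input_ids=ids, attention_mask=torch.ones_like(ids)).logits
    padded = m(
        input_ids=torch.cat([pad, ids], dim=1),
        attention_mask=torch.cat([torch.zeros_like(pad), torch.ones_like(ids)], dim=1),
    ).logits[:, 4:]  # drop the padded positions before comparing

# masked positions are excluded from attention, but absolute position ids and
# low-precision kernels can still move the numbers for the real tokens
print(torch.allclose(plain, padded, atol=1e-4))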

@younesbelkada
Contributor

Hi @mayank31398!
Thanks for pointing out this issue 💪
If I wrap up what I have understood from your issue: when doing batched generation, changing the value of one of the labels changes the loss for the other examples. If I understood correctly, the labels are not used when inferring here, so the problem should occur when computing the loss (i.e., the input text is always fixed, right?).
I tried your script on the main branch using gpt2 as below:

import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# lm_logits = torch.randn((4, 11, 250880), dtype=torch.bfloat16)

def compute_gen_loss(lm_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    batch_size = labels.shape[0]
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    loss = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1)
    )
    loss = loss.reshape(batch_size, -1)
    loss = loss.sum(dim=-1) / (shift_labels != -100).sum(dim=-1)
    return loss

def pad_ids(arrays, padding, max_length=-1):
    if (max_length < 0):
        max_length = max(list(map(len, arrays)))

    arrays = [[padding] * (max_length - len(array)) +
              array for array in arrays]

    return arrays


def forward(text: list, labels: list, conditional: bool = True):
    input_tokens = tokenizer(text).input_ids
    label_tokens = tokenizer(labels).input_ids

    input_ids = [x + y for (x, y) in zip(input_tokens, label_tokens)]
    attention_mask = [(len(x) + len(y)) * [1]
                      for (x, y) in zip(input_tokens, label_tokens)]
    if (conditional):
        labels = [[-100] * len(x) + y for (x, y)
                  in zip(input_tokens, label_tokens)]
    else:
        labels = input_ids

    pad = 3
    input_ids = pad_ids(input_ids, pad)
    attention_mask = pad_ids(attention_mask, 0)
    # labels need to be on output device
    labels = pad_ids(labels, -100)

    input_ids = torch.tensor(input_ids)
    attention_mask = torch.tensor(attention_mask)
    labels = torch.tensor(labels)
    lm_logits = model(
        input_ids=input_ids,
        attention_mask=attention_mask
    ).logits

    print(compute_gen_loss(lm_logits, labels.to(lm_logits.device)).cpu().tolist())

text = [
    "DeepSpeed",
    "DeepSpeed is a",
    "DeepSpeed is a machine",
    "DeepSpeed is a machine learning framework",
]
labels = [
    " is awesome.",
    " good person.",
    " that can wipe out the planet.",
    " for generating memes.",
]
forward(text, labels)

labels[0] = " is awesome. really awesome"
forward(text, labels)

labels[0] = " is awesome. really awesome. Try it."
forward(text, labels)

labels[0] = " is awesome. really awesome. Try it. You'll be surprised"
forward(text, labels)

labels[0] = " is awesome. really awesome. Try it. You'll be surprised. BLOOM was trained using DeepSpeed."
forward(text, labels)

labels[0] = " is awesome. really awesome. Try it. You'll be surprised. BLOOM was trained using DeepSpeed. Oh no the values are bugging out now."
forward(text, labels)

and getting

[10.3125, 7.0, 3.609375, 7.65625]
[8.25, 7.0, 3.609375, 7.65625]
[6.84375, 7.0, 3.609375, 7.65625]
[3.78125, 7.09375, 6.9375, 8.5625]
[4.34375, 9.5, 8.6875, 10.75]
[4.53125, 9.6875, 9.0, 12.125]

I suspect that logits may be flaky when using half-precision models, therefore I second what @thomasw21 suspected ;)!

@mayank31398
Contributor Author

Hey, first of all: sorry for the late reply.
Thanks for trying out my example with gpt2 @younesbelkada
Any way to get around this then?
I guess computing logits in bf16 might not be the best we can do?

@thomasw21
Contributor

thomasw21 commented Sep 14, 2022

Okay, I think the gpt2 test isn't instability. Essentially it's the absolute positional embeddings that are screwing with you: as you increase the label size you shift things to the right and add padding on the left, which is why you see the big shifts in the loss.
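
If you want the gpt2 run to be insensitive to left padding, one option (a sketch reusing the input_ids and attention_mask tensors built in forward() above, not something the posted script does) is to pass explicit position_ids derived from the attention mask, so the real tokens keep the same positions regardless of how much padding is added:

# pad positions get a dummy position of 1; real tokens are numbered from 0
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids = position_ids.masked_fill(attention_mask == 0, 1)

lm_logits = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
).logits

This only helps models with absolute positional embeddings; BLOOM uses ALiBi instead.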

I do think that the bloom test is instability. Notably, 3.28125 and 3.296875 are consecutive bfloat16 values.

>>> import numpy as np
>>> import torch
>>> torch.set_printoptions(precision=10)
>>> torch.frombuffer(bytes(np.array([83,64], np.int8)), dtype=torch.bfloat16)
tensor([3.2968750000], dtype=torch.bfloat16)
>>> torch.frombuffer(bytes(np.array([82,64], np.int8)), dtype=torch.bfloat16) # replace 83 with 82
tensor([3.2812500000], dtype=torch.bfloat16)

>>> torch.frombuffer(bytes(np.array([-94,64], np.int8)), dtype=torch.bfloat16)
tensor([5.0625000000], dtype=torch.bfloat16)
>>> torch.frombuffer(bytes(np.array([-93,64], np.int8)), dtype=torch.bfloat16)
tensor([5.0937500000], dtype=torch.bfloat16)

So as you said, you can try computing the logits in fp32, which will increase precision (but will be slower). There's a bit of a workaround involved, as you need to cast the embedding layers to fp32 and such.
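
As a rough sketch of that idea (using gpt2 on a single device for simplicity, not the sharded BLOOM setup), you can keep the transformer in bf16 and only do the final projection to logits in fp32:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16).eval()

inputs = tokenizer("DeepSpeed is awesome.", return_tensors="pt")

with torch.no_grad():
    hidden = model.transformer(**inputs).last_hidden_state  # bf16 hidden states
    # upcast both the hidden states and the (tied) LM head weight, so the
    # logits themselves are accumulated in fp32
    lm_logits = torch.nn.functional.linear(hidden.float(), model.lm_head.weight.float())

The loss can then be computed from these fp32 logits exactly as in compute_gen_loss above; whether that is enough to make the BLOOM numbers stable would need to be checked on the real model.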

@younesbelkada
Contributor

younesbelkada commented Sep 14, 2022

Everything makes sense in your explanation @thomasw21 ! Missed the absolute positional embedding part. Thanks for explaining it 💪

@mayank31398
Contributor Author

I guess this is not a fixable problem then, right?
I think even in BLOOM, ALiBi might be screwing with the attention values, right?
So even if we have padded, the result will change.
Thanks for the clarification @thomasw21.

I think we can close this?

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
