[Falcon Mamba] Unexpected model output with use_cache=False and model.train() #33234

Closed · jploski opened this issue Aug 31, 2024 · 5 comments · Fixed by #33195

jploski commented Aug 31, 2024

System Info

  • transformers version: 4.45.0.dev0
  • Platform: Linux-4.19.0-8-amd64-x86_64-with-glibc2.28
  • Python version: 3.10.9
  • Huggingface_hub version: 0.23.2
  • Safetensors version: 0.4.3
  • Accelerate version: 0.33.0
  • Accelerate config: - compute_environment: LOCAL_MACHINE
    • distributed_type: NO
    • mixed_precision: no
    • use_cpu: False
    • debug: False
    • num_processes: 1
    • machine_rank: 0
    • num_machines: 1
    • rdzv_backend: static
    • same_network: False
    • main_training_function: main
    • enable_cpu_affinity: False
    • downcast_bf16: False
    • tpu_use_cluster: False
    • tpu_use_sudo: False
  • PyTorch version (GPU?): 2.4.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA GeForce RTX 3090

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

#!/usr/bin/env python3

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = './falcon-mamba-7b-instruct'
#model_id = './mamba-130m-hf'

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
model.tokenizer = tokenizer

input_ids = tokenizer("A falcon mamba is a", return_tensors="pt").input_ids.to("cuda")

def test_generate(use_cache, train):
    model.train(mode=train)

    print(f"test_generate with use_cache={use_cache} and model.training={train}")

    outputs = model.generate(input_ids, do_sample=False, max_new_tokens=1, return_dict_in_generate=True, output_logits=True, use_cache=use_cache)

    # Print the top-10 tokens (id, logit, decoded text) for the first generated position
    logits = outputs['logits'][0]
    logits_sorted, indices = torch.sort(logits, descending=True)
    for i in range(10):
        tok = tokenizer.decode([indices[-1][i].item()])
        print(f"{indices[-1][i].item()} {logits_sorted[-1][i].item()} [{tok}]")


test_generate(True, False)
test_generate(False, False)
test_generate(True, True)
test_generate(False, True)

Observed (unexpected) output:

test_generate with use_cache=True and model.training=False
1842 137.0 [ type]
22277 136.0 [ snake]
27184 136.0 [ fictional]
45271 136.0 [ venom]
829 135.0 [ very]
2794 135.0 [ fast]
3766 135.0 [ highly]
6114 135.0 [ species]
12318 135.0 [ hybrid]
17846 135.0 [ deadly]
test_generate with use_cache=False and model.training=False
1842 137.0 [ type]
22277 136.0 [ snake]
27184 136.0 [ fictional]
45271 136.0 [ venom]
829 135.0 [ very]
2794 135.0 [ fast]
3766 135.0 [ highly]
6114 135.0 [ species]
12318 135.0 [ hybrid]
17846 135.0 [ deadly]
test_generate with use_cache=True and model.training=True
1842 137.0 [ type]
22277 136.0 [ snake]
27184 136.0 [ fictional]
45271 136.0 [ venom]
829 135.0 [ very]
2794 135.0 [ fast]
3766 135.0 [ highly]
6114 135.0 [ species]
12318 135.0 [ hybrid]
17846 135.0 [ deadly]
test_generate with use_cache=False and model.training=True
204 118.5 [ ]
627 118.0 [ new]
1842 118.0 [ type]
4307 117.5 [ gift]
241 117.0 [ a]
822 117.0 [ good]
914 117.0 [ great]
986 117.0 [ high]
1733 117.0 [ special]
829 116.5 [ very]

Expected behavior

I would expect the logits produced by the forward pass to be the same regardless of whether model.train(True) has been invoked or use_cache is False or True. This holds for mamba-130m-hf. However, when running the provided test script with falcon-mamba-7b-instruct, the output for the case model.train(True) with use_cache=False differs from the other three outputs.

Practical relevance: this was discovered during ORPO training, where model.training is True and use_cache=False. The outputs and the loss calculated during training do not match the outputs observed after training in evaluation mode.
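
To spell out the expected invariance as code (a minimal sketch, reusing model and input_ids from the script above and assuming the model has no dropout, so train mode should not change the math):

import itertools
import torch

# Sketch: the final-position logits of a single forward pass over the prompt
# should match for every combination of use_cache and train mode.
reference = None
for use_cache, train in itertools.product([True, False], [False, True]):
    model.train(train)
    with torch.no_grad():
        last_logits = model(input_ids, use_cache=use_cache).logits[:, -1, :]
    if reference is None:
        reference = last_logits
    else:
        # bf16 leaves room for small numerical noise, hence the loose tolerance
        assert torch.allclose(reference, last_logits, atol=1e-1), (use_cache, train)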

jploski added the bug label on Aug 31, 2024
vasqu (Contributor) commented Aug 31, 2024

#33195 should resolve this; the original kernels in the mamba-ssm package miss some additional RMS normalization.
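
(For context, a minimal sketch of a weightless RMS normalization as it would be applied to the B, C and dt projections; the function name is an assumption and this is not the actual transformers/mamba-ssm code:)

import torch

def rms_norm_no_weight(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMS-normalize over the last dimension without a learned scale
    # (sketch only; computed in float32 for numerical stability under bf16)
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype)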

jploski (Author) commented Aug 31, 2024

#33195 should resolve this; the original kernels in the mamba-ssm package miss some additional RMS normalization.

Thanks for the link. I tried with that PR; the output changes/improves, but it is still not identical to the other three cases:

test_generate with use_cache=False and model.training=True
1842 141.0 [ type]
22277 140.0 [ snake]
45271 140.0 [ venom]
1902 139.0 [ large]
2794 139.0 [ fast]
6114 139.0 [ species]
829 138.0 [ very]
916 138.0 [ long]
1385 138.0 [ small]
1462 138.0 [ non]

vasqu (Contributor) commented Sep 1, 2024

cc @younesbelkada

younesbelkada (Contributor) commented:

@jploski the output is not exactly the same but looks coherent compared to the previous results:

204 118.5 [ ]
627 118.0 [ new]
1842 118.0 [ type]
4307 117.5 [ gift]
241 117.0 [ a]
822 117.0 [ good]
914 117.0 [ great]
986 117.0 [ high]
1733 117.0 [ special]
829 116.5 [ very]

vs

1842 141.0 [ type]
22277 140.0 [ snake]
45271 140.0 [ venom]
1902 139.0 [ large]
2794 139.0 [ fast]
6114 139.0 [ species]
829 138.0 [ very]
916 138.0 [ long]
1385 138.0 [ small]
1462 138.0 [ non]

I'd suspect there are some numerical differences between the kernel and the non-kernel path which accumulate and ultimately end up not predicting exactly the same tokens.
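
(A rough sketch of how one could gauge that, reusing model and input_ids from the reproduction script above; not something run in this thread:)

import torch

# Run the same prompt through the eval path and the training path (both with
# use_cache=False) and look at the per-position maximum absolute logit
# difference. Pure bf16 noise should stay small; drift that grows with the
# position would point at a real divergence between the two paths.
with torch.no_grad():
    model.eval()
    logits_eval = model(input_ids, use_cache=False).logits.float()
    model.train()
    logits_train = model(input_ids, use_cache=False).logits.float()

per_position_diff = (logits_eval - logits_train).abs().amax(dim=-1)
print(per_position_diff.squeeze(0))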

jploski (Author) commented Sep 1, 2024

I'd suspect there are some numerical differences between the kernel and the non-kernel path which accumulate and ultimately end up not predicting exactly the same tokens.

@younesbelkada A valid point, but perplexity over WikiText seems significantly (~16%) worse for the train=True case. I evaluated it using the following script:

#!/usr/bin/env python3

import torch
import os
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = './falcon-mamba-7b-instruct'
#model_id = './mamba-130m-hf'

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)

# wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
# unzip wikitext-2-raw-v1.zip

with open("./wiki.test.raw", encoding="utf-8") as f:
    wiki_text = f.read()

input_ids = tokenizer(wiki_text, return_tensors="pt").input_ids.to("cuda")

def test_ppl(train):
    model.train(train)

    # Adapted from https://huggingface.co/docs/transformers/en/perplexity
    stride = 512
    seq_len = input_ids.size(1)
    nlls = []
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + stride, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        chunk_input_ids = input_ids[0, begin_loc:end_loc].unsqueeze(0).to('cuda')
        target_ids = chunk_input_ids.clone()
        target_ids[:,:-trg_len] = -100

        with torch.no_grad():
            outputs = model(chunk_input_ids, labels=target_ids, use_cache=False)

            # loss is calculated using CrossEntropyLoss which averages over valid labels
            # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
            # to the left by 1.
            neg_log_likelihood = outputs.loss

        nlls.append(neg_log_likelihood)

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    ppl = torch.exp(torch.stack(nlls).mean())
    print(f"train={train}, PPL={ppl.item()}")

test_ppl(False)
test_ppl(True)

With falcon-mamba-7b-instruct I'm getting PPL=7.3546 for train=False and PPL=8.5525 for train=True.

With mamba-130m-hf there is only a slight difference - PPL=27.1124 for train=False and PPL=27.1174 for train=True.
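
(A quick check of the ~16% figure from the reported numbers:)

# Relative PPL increase implied by the falcon-mamba-7b-instruct results above
ppl_eval, ppl_train = 7.3546, 8.5525
print(f"{(ppl_train / ppl_eval - 1) * 100:.1f}%")  # ~16.3%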
