Cache updating when use_cache = False #32843

Closed
2 of 4 tasks
ciaran-regan-ie opened this issue Aug 16, 2024 · 8 comments · Fixed by #32863

Comments

@ciaran-regan-ie

ciaran-regan-ie commented Aug 16, 2024

System Info

  • transformers version: 4.44.0
  • Platform: macOS-14.5-arm64-arm-64bit
  • Python version: 3.10.0
  • Huggingface_hub version: 0.24.5
  • Safetensors version: 0.4.4
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm experimenting with shuffling layers in a pre-trained model. The layer_idx inside the Attention object makes this difficult, as described in this issue. To work around this, I'm setting use_cache = False. However, even with use_cache = False, an error occurs because past_key_value.update is still called in the Attention forward pass. A simple solution would be to check use_cache in the forward pass by adding it to the condition:

```python
if past_key_value is not None and use_cache:
    cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
```
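The effect of that guard can be illustrated with a small framework-free sketch. `ToyCache` and `attention_step` are hypothetical names standing in for a Cache object and the attention forward pass, and cached lengths stand in for key/value tensors.

```python
from typing import Dict, Optional


class ToyCache:
    """Hypothetical stand-in for a KV cache; it only tracks cached positions per layer."""

    def __init__(self) -> None:
        self.seq_lens: Dict[int, int] = {}

    def update(self, new_len: int, layer_idx: int) -> int:
        # Add the new positions for this layer and return the total cached length.
        self.seq_lens[layer_idx] = self.seq_lens.get(layer_idx, 0) + new_len
        return self.seq_lens[layer_idx]


def attention_step(new_len: int, layer_idx: int, past_key_value: Optional[ToyCache], use_cache: bool) -> int:
    """Return the key length attention would see for this step.

    With the proposed `and use_cache` guard, a cache object that is passed in
    is left untouched when use_cache=False.
    """
    if past_key_value is not None and use_cache:
        return past_key_value.update(new_len, layer_idx)
    return new_len  # attend only to the current tokens; the cache is not modified
```

For example, `attention_step(5, 0, cache, use_cache=False)` returns `5` and leaves `cache.seq_lens` empty, while two calls with `use_cache=True` (5 tokens, then 1) return `5` and then `6`.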

Here is my code to reproduce the problem. The first run succeeds because the layers are still in their original order, but the second run fails when the cache attempts to update.

```python
import random

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def run_with_custom_order(model, tokenizer, device, prompts, order):
    # Reorder the decoder layers; each layer keeps the layer_idx it was built with.
    original_layers = model.model.layers
    model.model.layers = torch.nn.ModuleList([original_layers[i] for i in order])
    try:
        for prompt in prompts:
            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            input_length = inputs["input_ids"].shape[1]
            outputs = model.generate(**inputs, max_new_tokens=20, pad_token_id=tokenizer.eos_token_id, use_cache=False)
            generated_text = tokenizer.decode(outputs[0, input_length:], skip_special_tokens=True)
            print(generated_text)
    finally:
        # Restore the original layer order even if generation fails.
        model.model.layers = original_layers


def main():
    device = (
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    llm_name = "microsoft/Phi-3-mini-4k-instruct"
    tokenizer = AutoTokenizer.from_pretrained(llm_name)
    model = AutoModelForCausalLM.from_pretrained(llm_name, torch_dtype=torch.bfloat16, trust_remote_code=True)
    model.to(device)
    model.config.pad_token_id = tokenizer.eos_token_id
    num_layers = len(model.model.layers)

    # The original script loaded one MMLU question with a local helper (load_qa from utils).
    # Hypothetical placeholder question so the script runs on its own:
    questions = ["Which of the following is a prime number?\nA. 4\nB. 6\nC. 7\nD. 9"]
    prompts = [f"<|user|>\n{question}\nChoose A, B, C, or D:<|end|>\n<|assistant|>" for question in questions]

    order = list(range(num_layers))
    run_with_custom_order(model, tokenizer, device, prompts, order)  # original order: runs

    random.shuffle(order)
    run_with_custom_order(model, tokenizer, device, prompts, order)  # shuffled order: fails on 4.44.0


if __name__ == "__main__":
    main()
```
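Why the second run fails can be sketched without loading a model. This is a hypothetical, simplified model: `ToyListCache`, `ToyLayer`, `run_prefill`, and `reindex` are illustrative names, and the assumption is that the cache is list-backed, appending an entry the first time it sees a layer_idx beyond its current size and concatenating afterwards. Because each layer keeps the index it was constructed with, a shuffled order makes the indices disagree with the positions the cache has filled.

```python
from typing import List


class ToyListCache:
    """Hypothetical list-backed cache (lengths stand in for key/value tensors).

    Assumption: appends a new entry the first time a layer_idx beyond its
    current size is seen, and concatenates onto that entry afterwards.
    """

    def __init__(self) -> None:
        self.seq_lens: List[int] = []

    def update(self, new_len: int, layer_idx: int) -> int:
        if len(self.seq_lens) <= layer_idx:
            self.seq_lens.append(new_len)  # lands at position len(seq_lens), not at layer_idx
        else:
            self.seq_lens[layer_idx] += new_len
        return self.seq_lens[layer_idx]  # IndexError if layer_idx skipped ahead


class ToyLayer:
    """Hypothetical decoder layer whose index is fixed at construction time."""

    def __init__(self, layer_idx: int) -> None:
        self.layer_idx = layer_idx


def run_prefill(layers: List[ToyLayer], prompt_len: int) -> List[int]:
    """Run one prompt through the layers in their current order, using a fresh cache."""
    cache = ToyListCache()
    return [cache.update(prompt_len, layer.layer_idx) for layer in layers]


def reindex(layers: List[ToyLayer]) -> None:
    """Hypothetical workaround: make each layer's index match its new position."""
    for position, layer in enumerate(layers):
        layer.layer_idx = position
```

With `layers = [ToyLayer(i) for i in range(4)]`, `run_prefill(layers, 3)` returns `[3, 3, 3, 3]`; after reordering to `[layers[i] for i in [2, 0, 3, 1]]` it raises `IndexError`, and after `reindex(...)` it returns `[3, 3, 3, 3]` again. Applying the same idea to the real model would mean updating each moved layer's attention index (assumed here to live at `self_attn.layer_idx`) and restoring it afterwards, since the layer objects are shared with the original order.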

Expected behavior

When use_cache = False, the cache should not be updated, right?

Happy to help with PRs if you feel it's necessary!

@amyeroberts
Collaborator

cc @gante too for the cache

@nickfraser

Related, but I'm not sure whether this should be a separate issue. The problem is actually slightly more general than what you've described: for example, the forward function of LlamaAttention contains no reference to use_cache.

This has the following consequences:

  1. If past_key_value is passed to LlamaAttention, the issue stated above will occur (i.e., the cache will always be updated)
  2. If past_key_value is not passed (i.e., past_key_value=None), the returned past_key_value will also always be None, regardless of the value of use_cache
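Both consequences follow from the flag never being read, which a short hypothetical sketch makes concrete. `ToyCache` and `attention_forward` are illustrative names, not library code.

```python
from typing import Optional, Tuple


class ToyCache:
    """Hypothetical stand-in for a Cache object; it only counts update() calls."""

    def __init__(self) -> None:
        self.updates = 0

    def update(self) -> None:
        self.updates += 1


def attention_forward(past_key_value: Optional[ToyCache], use_cache: bool) -> Tuple[str, Optional[ToyCache]]:
    """Hypothetical model of the behavior described above: use_cache is accepted but never read."""
    if past_key_value is not None:
        past_key_value.update()  # consequence 1: updated even when use_cache=False
    return "attn_output", past_key_value  # consequence 2: None in, None out
```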

@nickfraser

nickfraser commented Aug 16, 2024

Also FYI, it appears this issue has existed since the new cache structure was introduced in v4.36; the correct behaviour existed in prior versions, when tuples were used (e.g., here).
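For contrast, the tuple-based approach can be sketched as follows. This is a simplified, hypothetical model (lists stand in for key/value tensors, and `legacy_attention_cache` is an illustrative name); the assumption is that older layers concatenated any past tuple they received and built a new `(key, value)` tuple to return only when `use_cache` was true. Because the tuples are rebuilt rather than mutated, passing a past with `use_cache=False` never changed it.

```python
from typing import List, Optional, Tuple

KV = Tuple[List[int], List[int]]  # (keys, values); lists stand in for tensors


def legacy_attention_cache(
    key: List[int], value: List[int], past_key_value: Optional[KV], use_cache: bool
) -> Tuple[KV, Optional[KV]]:
    """Return the (key, value) used for attention and the cache to hand back."""
    if past_key_value is not None:
        # Build new lists; the incoming tuple is left untouched.
        key = past_key_value[0] + key
        value = past_key_value[1] + value
    present = (key, value) if use_cache else None  # use_cache decides what is returned
    return (key, value), present
```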

@gante
Member

gante commented Aug 16, 2024

Hi @ciaran-regan-ie (and @nickfraser )👋 Thank you for opening the issue and elaborating on the problem!

Before taking your comments and projects into consideration, let me share my view (and the context behind some changes in v4.36). We moved towards having a custom object to store the cache, the Cache classes. Many use cases, such as using StaticCache for torch.compile, require passing an instantiated object, even if its actual contents are empty. As such, use_cache lost some of its importance: we now often pass an empty cache, which implies use_cache=True. We have also identified some cases where use_cache must be False, such as at training time. I dislike implicit/redundant flags because they tend to create problems, so I would love to remove use_cache when (and if) we release transformers v5.

Our code, therefore, gravitated to its current state, where we check use_cache in the core modeling class and create a new cache instance if needed. From that point onwards, we assume that past_key_values is not None == we want to use cache. Why wouldn't we? We are either passing the cache object or telling the model to create a new one.
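The flow described here can be sketched in plain Python. All names are hypothetical. The last case in the usage shows the gap this issue reports: a cache that is passed in explicitly is still written even when `use_cache=False`, because only the outer-most class looks at the flag.

```python
from typing import List, Optional


class ToyCache:
    """Hypothetical cache object that records which layers wrote to it."""

    def __init__(self) -> None:
        self.written: List[int] = []


def decoder_layer(layer_idx: int, past_key_values: Optional[ToyCache]) -> None:
    # Inner blocks never read use_cache: a non-None cache means "use it".
    if past_key_values is not None:
        past_key_values.written.append(layer_idx)


def core_model_forward(
    num_layers: int, past_key_values: Optional[ToyCache] = None, use_cache: bool = False
) -> Optional[ToyCache]:
    # use_cache is only consulted here, in the outer-most modeling class.
    if use_cache and past_key_values is None:
        past_key_values = ToyCache()
    for layer_idx in range(num_layers):
        decoder_layer(layer_idx, past_key_values)
    return past_key_values
```

`core_model_forward(3, use_cache=True)` returns a new cache written by layers `[0, 1, 2]`, and `core_model_forward(3, use_cache=False)` returns `None`; but `core_model_forward(2, existing_cache, use_cache=False)` still writes layers `[0, 1]` into `existing_cache`.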

From what I'm reading in your comments, my assumption may be incorrect!

  • @ciaran-regan-ie in your case, I think the problem is in generate: it creates a cache by default and passes it to the model without checking use_cache. Will open a PR to fix it! 💪
  • @nickfraser could you elaborate on your use case, if my assumption is incorrect? 🤗
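A check of the kind described for generate could look like the sketch below: a default cache is created only when the caller has not supplied one and has not set `use_cache=False`. `prepare_model_kwargs` and `ToyCache` are hypothetical names, and the actual change made in #32863 is not reproduced here.

```python
from typing import Any, Dict


class ToyCache:
    """Hypothetical placeholder for a default cache instance."""


def prepare_model_kwargs(model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical generate()-side preparation that respects use_cache.

    Assumption: use_cache defaults to True when the caller does not set it.
    """
    kwargs = dict(model_kwargs)  # do not mutate the caller's dict
    use_cache = kwargs.get("use_cache", True)
    if use_cache and kwargs.get("past_key_values") is None:
        kwargs["past_key_values"] = ToyCache()
    return kwargs
```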

@nickfraser

Hi @gante,

Thanks for the detailed reply. Also, please feel free to tell me to open a new issue if that is more appropriate. I understand the new behaviour. In my case, I was calling sub-layers of a Llama-based model directly (with use_cache=True), for some research work, which causes some strange behaviour. In this case, the cache instantiation code in LlamaModel is bypassed and LlamaDecoderLayer returns (torch.Tensor, None).

I find this behaviour to be quite unintuitive, but I accept that this is a niche use-case.

@gante
Member

gante commented Aug 16, 2024

@nickfraser If I understand correctly, you were expecting LlamaDecoderLayer to return (torch.Tensor, Cache) when use_cache=True (Cache being a new instance), correct? I would expect a cache to be returned too, given the input argument name 😅

Shifting the cache instantiation from the inner-most block (prior to v4.36) to the outer-most block was a hard requirement to enable torch.compile, but it does conflict with reasonable expectations for use_cache.

Here's what I'm thinking of doing:

  1. Deprecate use_cache except in the core modeling class (e.g. LlamaModel) and classes that use it (e.g. LlamaForCausalLM). Internal blocks will essentially assume use_cache = past_key_values is not None. Having multiple places where a new cache can be instantiated will bloat the modeling code and is prone to errors, and I think we can assume that folks that use internal layers are power users and know how to instantiate a cache :)
  2. Add a clear message in the deprecation warning: Cache instantiation will only happen at the core model level. If a cache is to be used, use past_key_values.
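Under this proposal, an internal block would ignore use_cache apart from warning about it, and decide whether to touch the cache purely from past_key_values. The sketch below is hypothetical: `internal_block`, `ToyCache`, and the choice of `FutureWarning` (used because, unlike `DeprecationWarning`, Python shows it to end users by default) are assumptions, not the library's code.

```python
import warnings
from typing import Optional


class ToyCache:
    """Hypothetical cache object that counts how often it is written."""

    def __init__(self) -> None:
        self.updates = 0


def internal_block(past_key_values: Optional[ToyCache] = None, use_cache: Optional[bool] = None) -> bool:
    """Return True if this (hypothetical) internal layer wrote to the cache."""
    if use_cache is not None:
        warnings.warn(
            "`use_cache` is deprecated for internal layers: cache instantiation only happens at the "
            "core model level. If a cache is to be used, pass `past_key_values`.",
            FutureWarning,
            stacklevel=2,
        )
    # Effectively use_cache = past_key_values is not None.
    if past_key_values is not None:
        past_key_values.updates += 1
        return True
    return False
```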

WDYT? It should make things much cleaner from a user perspective, while being manageable on our end 🤗

(cc @ArthurZucker )

@nickfraser

> If I understand correctly, you were expecting LlamaDecoderLayer to return (torch.Tensor, Cache) when use_cache=True (Cache being a new instance), correct?

Yes, exactly.

> Shifting the cache instantiation from the inner-most block (prior to v4.36) to the outer-most block was a hard requirement to enable torch.compile

Makes sense.

> WDYT? It should make things much cleaner from a user perspective, while being manageable on our end 🤗

Your suggestion makes a lot of sense to me - sounds great! Thanks for being so amenable too! <3

@ciaran-regan-ie
Author

@gante Thank you so much!
