diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 20cf008548..7c9b69e212 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -600,6 +600,7 @@ def generate( if model_kwargs["reduce_recompile"]: assert generation_config.bucket_size if generation_config.reuse_cache: + assert self.config.model_type in ["llama"], "reuse_cache only supported by llama at the moment" assert generation_config.bucket_size <= 0, "reuse_cache and bucketing flags set together" if generation_config.static_shapes: