From 13f0aedd522c9a6310cb337135b54a4e22815cce Mon Sep 17 00:00:00 2001
From: Yeonsil Yoon
Date: Wed, 25 Sep 2024 14:11:44 -0700
Subject: [PATCH] Update bloom attention forward reshape following the
 transformers change

---
 .../habana/transformers/models/bloom/modeling_bloom.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/optimum/habana/transformers/models/bloom/modeling_bloom.py b/optimum/habana/transformers/models/bloom/modeling_bloom.py
index c06d42e34d..5b0a770451 100644
--- a/optimum/habana/transformers/models/bloom/modeling_bloom.py
+++ b/optimum/habana/transformers/models/bloom/modeling_bloom.py
@@ -137,11 +137,9 @@ def gaudi_bloom_attention_forward(
     # 3 x [batch_size, num_heads, seq_length, head_dim]
     query_layer, key_layer, value_layer = self._reshape(fused_qkv)
 
-    batch_size, q_length, _, _ = query_layer.shape
-
-    query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
-    key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
-    value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
+    query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
+    key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(1, 2)
+    value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
 
     # Collapse views to improve performance on HPU
     query_layer = query_layer.contiguous()
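
Reviewer note: below is a minimal standalone sketch of why the simpler reshape is
equivalent once _reshape returns tensors laid out as
[batch_size, num_heads, seq_length, head_dim] (per the comment in the hunk). The
shape values and variable names are hypothetical placeholders for illustration,
not taken from the model; dropping the re-derivation of batch_size from
query_layer.shape assumes, as in the upstream transformers implementation, that
batch_size is already defined earlier in the function from fused_qkv.shape.

import torch

# Hypothetical example shapes; in the model these come from the config.
batch_size, num_heads, seq_length, head_dim = 2, 4, 8, 16

# New layout: [batch_size, num_heads, seq_length, head_dim]. The batch and
# head dimensions are adjacent, so heads fold into the batch dimension with
# a plain reshape, no transpose/permute needed first.
q = torch.randn(batch_size, num_heads, seq_length, head_dim)

new_q = q.reshape(batch_size * num_heads, -1, head_dim)
new_k = q.reshape(batch_size * num_heads, -1, head_dim).transpose(1, 2)

# Old layout: [batch_size, seq_length, num_heads, head_dim], which required
# a transpose (for Q/V) or permute (for K) before flattening.
q_old_layout = q.transpose(1, 2)
old_q = q_old_layout.transpose(1, 2).reshape(batch_size * num_heads, seq_length, head_dim)
old_k = q_old_layout.permute(0, 2, 3, 1).reshape(batch_size * num_heads, head_dim, seq_length)

assert torch.equal(new_q, old_q)  # [batch * heads, seq, head_dim]
assert torch.equal(new_k, old_k)  # [batch * heads, head_dim, seq]
print("new reshape path matches the old transpose/permute path")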