Memory optimization for gpt_bigcode (#4) (huggingface#1513)
Co-authored-by: Urszula Golowicz <urszula.golowicz@intel.com>
2 people authored and Liangyx2 committed Jan 20, 2025
1 parent 3d926be commit d37b3cb
Showing 1 changed file with 90 additions and 1 deletion.
@@ -24,7 +24,12 @@
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention, GPTBigCodeForCausalLM
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
GPTBigCodeAttention,
GPTBigCodeForCausalLM,
upcast_masked_softmax,
upcast_softmax,
)

from ...modeling_attn_mask_utils import GaudiAttentionMaskConverter

@@ -57,6 +62,90 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA is not None else None
self.block_size = 4096

def _attn(self, query, key, value, attention_mask=None, head_mask=None):
"""
This method should be deleted when https://github.com/huggingface/transformers/pull/34508 is merged.
Copied from GPTBigCodeAttention._attn: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
The only difference is:
- torch.matmul is used instead of torch.baddbmm when the device of query is not cpu (see the sketch after this method)
"""
dtype = query.dtype
softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
upcast = dtype != softmax_dtype

unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
scale_factor = unscale**-1
if self.scale_attn_weights:
scale_factor /= self.head_dim**0.5

# MQA models: (batch_size, query_length, num_heads * head_dim)
# MHA models: (batch_size, num_heads, query_length, head_dim)
query_shape = query.shape
batch_size = query_shape[0]
key_length = key.size(-1)
if self.multi_query:
# (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
# -> (batch_size, query_length, num_heads, key_length)
query_length = query_shape[1]
attn_shape = (batch_size, query_length, self.num_heads, key_length)
attn_view = (batch_size, query_length * self.num_heads, key_length)
# No copy needed for MQA 2, or when layer_past is provided.
query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
else:
# (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length)
# -> (batch_size, num_heads, query_length, key_length)
query_length = query_shape[2]
attn_shape = (batch_size, self.num_heads, query_length, key_length)
attn_view = (batch_size * self.num_heads, query_length, key_length)
# Always copies
query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim)
# No copy when layer_past is provided.
key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length)

attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype)
if query.device.type == "cpu":
# This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588.
# The bug was fixed in https://github.com/pytorch/pytorch/pull/96086,
# but the fix has not been released as of pytorch version 2.0.0.
attn_weights = torch.zeros_like(attn_weights)
attn_weights = torch.baddbmm(attn_weights, query, key, beta=1, alpha=scale_factor).view(attn_shape)
else:
# Formula for torch.baddbmm: out = beta * attn_weights + scale_factor * (query ⋅ key)
# for beta = 0, it simplifies to: out = scale_factor * (query ⋅ key)
attn_weights = (torch.matmul(query, key) * scale_factor).view(attn_shape)

if upcast:
# Use a fused kernel to prevent a large overhead from casting and scaling.
# Sub-optimal when the key length is not a multiple of 8.
if attention_mask is None:
attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype)
else:
mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)
attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype)
else:
if attention_mask is not None:
mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)

# The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion.
attn_weights = torch.where(attention_mask, attn_weights, mask_value)

attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

attn_weights = self.attn_dropout(attn_weights)

# Mask heads if we want to
if head_mask is not None:
if self.multi_query:
head_mask = head_mask.transpose(1, 2)
attn_weights = attn_weights * head_mask

if self.multi_query:
attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
else:
attn_output = torch.matmul(attn_weights, value)

return attn_output, attn_weights
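
For reference, here is a minimal, self-contained sketch (not part of the commit; all tensor sizes are illustrative assumptions) checking the equivalence that the non-CPU branch above relies on: torch.baddbmm over a zero-filled input tensor with alpha=scale_factor gives the same result as torch.matmul followed by scaling.

import torch

# Illustrative sizes only; they do not come from the model config.
batch_size, query_length, num_heads, head_dim, key_length = 2, 5, 4, 8, 5
scale_factor = head_dim**-0.5

# MQA layout used in _attn: query flattened to (batch, query_length * num_heads, head_dim),
# key laid out as (batch, head_dim, key_length).
query = torch.randn(batch_size, query_length * num_heads, head_dim)
key = torch.randn(batch_size, head_dim, key_length)

# CPU branch: baddbmm over a zero-initialized buffer (beta=1 on zeros behaves like beta=0).
zeros = torch.zeros(batch_size, query_length * num_heads, key_length)
attn_baddbmm = torch.baddbmm(zeros, query, key, beta=1, alpha=scale_factor)

# Non-CPU branch: plain matmul followed by scaling, with no extra input buffer.
attn_matmul = torch.matmul(query, key) * scale_factor

assert torch.allclose(attn_baddbmm, attn_matmul, atol=1e-6)

The sketch only shows that the two branches are numerically interchangeable; the memory saving named in the commit title presumably comes from not having to materialize the baddbmm input tensor on the non-CPU device.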

def gaudi_flash_attn_v1(
self,
query_layer,
