Finish adding support for torch.compile dynamic shapes #30919

Merged: 1 commit, May 23, 2024
src/transformers/models/dbrx/modeling_dbrx.py (6 changes: 5 additions & 1 deletion)

@@ -615,13 +615,17 @@ def forward(
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()

+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        is_causal = True if causal_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attn_pdrop if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            is_causal=is_causal,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
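The same pattern is applied to the two files below. As a reading aid (not part of the PR), here is a minimal sketch of what the hunk does, using a standalone function and hypothetical toy tensor shapes in place of the model's attention module: the is_causal boolean is resolved in an ordinary Python statement before the SDPA call, and the whole function is compiled with both fullgraph=True and dynamic=True.

# Minimal sketch of the pattern adopted in this PR; the function name and shapes are illustrative only.
import torch
import torch.nn.functional as F


def sdpa_block(query_states, key_states, value_states, causal_mask=None):
    q_len = query_states.shape[-2]
    # Resolve the causal flag before the kernel call (mirroring the hunks above and below)
    # instead of writing the conditional inline inside scaled_dot_product_attention.
    is_causal = True if causal_mask is None and q_len > 1 else False
    return F.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=causal_mask,
        dropout_p=0.0,
        is_causal=is_causal,
    )


# fullgraph=True disallows graph breaks; dynamic=True requests symbolic (dynamic) shapes.
compiled_sdpa_block = torch.compile(sdpa_block, fullgraph=True, dynamic=True)

q = torch.randn(1, 8, 16, 64)  # (batch, heads, q_len, head_dim), hypothetical sizes
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
out = compiled_sdpa_block(q, k, v)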
src/transformers/models/jamba/modeling_jamba.py (8 changes: 6 additions & 2 deletions)

@@ -724,14 +724,18 @@ def forward(
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()

+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+            is_causal=is_causal,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
src/transformers/models/jetmoe/modeling_jetmoe.py (8 changes: 5 additions & 3 deletions)

@@ -624,15 +624,17 @@ def forward(
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()

-        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
-        # relying on the `is_causal` argument.
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        is_causal = True if causal_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            is_causal=is_causal,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
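For completeness, a short usage sketch (again not part of the PR) that reuses compiled_sdpa_block from the sketch above and feeds it several query lengths greater than one; under dynamic=True these calls should be served by a single symbolic-shape graph, while q_len == 1 (single-token decoding) takes the other branch of the hoisted conditional, which is the case the AttentionMaskConverter.to_causal_4d comment refers to.

# Exercising the compiled sketch at a few query lengths; sizes are illustrative only.
for q_len in (4, 32, 257):
    q = torch.randn(1, 8, q_len, 64)
    k = torch.randn(1, 8, q_len, 64)
    v = torch.randn(1, 8, q_len, 64)
    out = compiled_sdpa_block(q, k, v)
    print(q_len, tuple(out.shape))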