Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor mistral and phi3 #12605

Merged
merged 3 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 16 additions & 39 deletions python/llm/src/ipex_llm/transformers/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,9 @@ def _optimize_pre(model, qtype=None):
elif model.config.model_type == "mllama":
from ipex_llm.transformers.models.mllama import merge_qkv
model.apply(merge_qkv)
elif model.config.model_type == "mistral":
from ipex_llm.transformers.models.mistral import merge_qkv
model.apply(merge_qkv)
elif model.config.model_type == "minicpm":
from ipex_llm.transformers.models.minicpm import merge_qkv, apply_residual_scale
model.apply(merge_qkv)
Expand Down Expand Up @@ -1901,43 +1904,17 @@ def _optimize_post(model, lightweight_bmm=False):
else:
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
if version.parse(trans_version) >= version.parse("4.36.0"):
from ipex_llm.transformers.models.mistral import mistral_model_forward_4_36
if version.parse(trans_version) >= version.parse("4.39.0"):
from ipex_llm.transformers.models.mistral import \
mistral_attention_forward_4_39
convert_forward(model,
module.MistralAttention,
mistral_attention_forward_4_39
)
else:
from ipex_llm.transformers.models.mistral import mistral_attention_forward_4_36
convert_forward(model,
module.MistralAttention,
mistral_attention_forward_4_36
)
convert_forward(model,
module.MistralModel,
mistral_model_forward_4_36
)
convert_forward(model,
module.MistralRMSNorm,
llama_rms_norm_forward)
convert_forward(model,
module.MistralMLP,
llama_mlp_forward)
else:
from ipex_llm.transformers.models.mistral import mistral_attention_forward
convert_forward(model,
module.MistralAttention,
mistral_attention_forward
)
convert_forward(model,
module.MistralRMSNorm,
llama_rms_norm_forward)
convert_forward(model,
module.MistralMLP,
llama_mlp_forward)

from ipex_llm.transformers.models.mistral import mistral_model_forward
from ipex_llm.transformers.models.mistral import mistral_attention_forward
from ipex_llm.transformers.models.common import rms_norm_forward
from ipex_llm.transformers.models.common import mlp_silu_forward

convert_forward(model, module.MistralModel, mistral_model_forward)
convert_forward(model, module.MistralAttention, mistral_attention_forward)
convert_forward(model, module.MistralSdpaAttention, mistral_attention_forward)
convert_forward(model, module.MistralRMSNorm, rms_norm_forward)
convert_forward(model, module.MistralMLP, mlp_silu_forward)
elif model.config.model_type == "gemma":
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
Expand Down Expand Up @@ -2078,8 +2055,8 @@ def safe_bmm_fwd(*args, **kwargs):
convert_forward(model, module.Phi3Attention, attention_forward)
from ipex_llm.transformers.models.phi3 import mlp_forward
convert_forward(model, module.Phi3MLP, mlp_forward)
from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward
convert_forward(model, module.Phi3RMSNorm, phi3_rms_norm_forward)
from ipex_llm.transformers.models.common import rms_norm_forward
convert_forward(model, module.Phi3RMSNorm, rms_norm_forward)
if model.config.model_type == "phi3":
from ipex_llm.transformers.models.phi3 import phi3_model_forward_wrapper
model_forward = phi3_model_forward_wrapper(module.Phi3Model.forward)
Expand Down
11 changes: 8 additions & 3 deletions python/llm/src/ipex_llm/transformers/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,13 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
key = repeat_kv(key, n_heads // n_kv_heads)
value = repeat_kv(value, n_heads // n_kv_heads)

attn_output = torch.nn.functional.scaled_dot_product_attention(
query, key, value, mask, is_causal=is_causal, scale=scale
)
if is_causal and mask is None:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query, key, value, is_causal=is_causal, scale=scale
)
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query, key, value, mask, scale=scale
)
attn_output = attn_output.to(dtype) # workaround ipex 2.1's bug
return attn_output
Loading
Loading