Refactor mixtral moe block. (#1635)
lkk12014402 authored and regisss committed Dec 23, 2024
1 parent 544eff7 · commit f91946f
Showing 5 changed files with 21 additions and 12 deletions.
4 changes: 2 additions & 2 deletions examples/trl/README.md
@@ -43,11 +43,11 @@ $ pip install -U -r requirements.txt
--use_flash_attention
```
-2. Supervised fine-tuning of the mistralai/Mixtral-8x7B-v0.1 on 4 cards:
+2. Supervised fine-tuning of the mistralai/Mixtral-8x7B-Instruct-v0.1 on 4 cards:
```
DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 python ../gaudi_spawn.py --world_size 4 --use_deepspeed sft.py \
-    --model_name_or_path mistralai/Mixtral-8x7B-v0.1 \
+    --model_name_or_path mistralai/Mixtral-8x7B-Instruct-v0.1 \
--dataset_name "philschmid/dolly-15k-oai-style" \
--subset 'data/' \
--streaming False \
18 changes: 8 additions & 10 deletions optimum/habana/transformers/modeling_utils.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import os

import accelerate
import transformers
@@ -214,6 +213,7 @@
gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
gaudi_mistral_rmsnorm_forward,
gaudi_mixtral_block_dynamic_moe_forward,
+gaudi_mixtral_block_moe_forward,
gaudi_mixtral_block_sparse_moe_forward,
gaudi_mixtral_rmsnorm_forward,
gaudi_opt_attention_forward,
@@ -555,15 +555,13 @@ def adapt_transformers_to_gaudi():
transformers.models.mixtral.modeling_mixtral.MixtralAttention = GaudiMixtralAttention
transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM = GaudiMixtralForCausalLM
transformers.models.mixtral.modeling_mixtral.MixtralModel = GaudiMixtralModel
-# We need this workaround until moe op in hpu is supporting fp8
-if os.environ.get("QUANT_CONFIG"):
-    transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = (
-        gaudi_mixtral_block_sparse_moe_forward
-    )
-else:
-    transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = (
-        gaudi_mixtral_block_dynamic_moe_forward
-    )
+transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.sparse_moe_forward = (
+    gaudi_mixtral_block_sparse_moe_forward
+)
+transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.dynamic_moe_forward = (
+    gaudi_mixtral_block_dynamic_moe_forward
+)
+transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = gaudi_mixtral_block_moe_forward
transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer = GaudiMixtralDecoderLayer
transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm.forward = gaudi_mixtral_rmsnorm_forward
transformers.models.mixtral.configuration_mixtral.MixtralConfig = MixtralConfig
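To show what the new patching scheme does, here is a minimal, self-contained sketch of the same pattern: both MoE implementations are attached to the block class and a single `forward` picks one per call, instead of committing to one implementation at patch time via `QUANT_CONFIG`. The class and placeholder functions below are illustrative stand-ins, not code from the repository.

```python
import os


class MoeBlock:
    """Stand-in for MixtralSparseMoeBlock; names here are illustrative only."""

    training = False


def sparse_forward(self, hidden_states):
    # Placeholder for gaudi_mixtral_block_sparse_moe_forward (fp8-friendly path).
    return f"sparse({hidden_states})"


def dynamic_forward(self, hidden_states):
    # Placeholder for gaudi_mixtral_block_dynamic_moe_forward (dynamic MoE op on HPU).
    return f"dynamic({hidden_states})"


def dispatching_forward(self, hidden_states):
    # Same rule as gaudi_mixtral_block_moe_forward in this commit:
    # take the dynamic MoE path only for inference without fp8 quantization.
    if not self.training and not os.environ.get("QUANT_CONFIG"):
        return self.dynamic_moe_forward(hidden_states)
    return self.sparse_moe_forward(hidden_states)


# The patching above attaches both implementations plus a single dispatcher:
MoeBlock.sparse_moe_forward = sparse_forward
MoeBlock.dynamic_moe_forward = dynamic_forward
MoeBlock.forward = dispatching_forward

print(MoeBlock().forward("x"))  # "dynamic(x)" unless QUANT_CONFIG is set or the module is training
```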
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/__init__.py
@@ -177,6 +177,7 @@
GaudiMixtralModel,
MixtralConfig,
gaudi_mixtral_block_dynamic_moe_forward,
+gaudi_mixtral_block_moe_forward,
gaudi_mixtral_block_sparse_moe_forward,
gaudi_mixtral_rmsnorm_forward,
)
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/mixtral/__init__.py
@@ -5,6 +5,7 @@
GaudiMixtralForCausalLM,
GaudiMixtralModel,
gaudi_mixtral_block_dynamic_moe_forward,
+gaudi_mixtral_block_moe_forward,
gaudi_mixtral_block_sparse_moe_forward,
gaudi_mixtral_rmsnorm_forward,
)
9 changes: 9 additions & 0 deletions optimum/habana/transformers/models/mixtral/modeling_mixtral.py
@@ -22,6 +22,7 @@

import contextlib
import math
+import os
from typing import List, Optional, Tuple, Union

import habana_frameworks.torch.core as htcore
@@ -357,6 +358,14 @@ def forward(
return attn_output, attn_weights, past_key_value


+def gaudi_mixtral_block_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    # We need this workaround until moe op in hpu is supporting fp8
+    if not self.training and not os.environ.get("QUANT_CONFIG"):
+        return self.dynamic_moe_forward(hidden_states)
+
+    return self.sparse_moe_forward(hidden_states)


def gaudi_mixtral_block_sparse_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Copied from MixtralSparseMoeBlock.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py
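Taken together, the dispatch rule introduced here is: the dynamic MoE path is used only for inference without fp8 quantization, while training or setting `QUANT_CONFIG` keeps the sparse implementation. A hedged sketch of how this could be checked on a Gaudi machine with optimum-habana installed; the quantization file path is an illustrative assumption, not a file shipped with this commit.

```python
import os

# Assumes a Gaudi/HPU environment with optimum-habana installed.
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()

from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

# After patching, the single entry point is the dispatcher added in this commit,
# and both implementations remain reachable as methods on the block.
print(MixtralSparseMoeBlock.forward.__name__)              # gaudi_mixtral_block_moe_forward
print(MixtralSparseMoeBlock.dynamic_moe_forward.__name__)  # gaudi_mixtral_block_dynamic_moe_forward
print(MixtralSparseMoeBlock.sparse_moe_forward.__name__)   # gaudi_mixtral_block_sparse_moe_forward

# Setting QUANT_CONFIG (fp8 quantization) makes even inference use the sparse path;
# the JSON path below is only an example.
os.environ["QUANT_CONFIG"] = "quantization_config/maxabs_quant.json"
```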
