@@ -179,20 +179,16 @@ def patch_scoped_linear_all_reduce(model):
 
 def get_torch_compiled_model(model, logger):
     # for gpt_bigcode, mpt, bloom, gpt2 model_type
-    if hasattr(model, 'transformer'):
+    if hasattr(model, "transformer"):
         model.transformer = torch.compile(
             model.transformer, backend="hpu_backend", options={"keep_input_mutations": True}
         )
     # for gpt_neox
-    elif hasattr(model, 'gpt_neox'):
-        model.gpt_neox = torch.compile(
-            model.gpt_neox, backend="hpu_backend", options={"keep_input_mutations": True}
-        )
+    elif hasattr(model, "gpt_neox"):
+        model.gpt_neox = torch.compile(model.gpt_neox, backend="hpu_backend", options={"keep_input_mutations": True})
     # for llama, mistral, mixtral, qwen2
-    elif hasattr(model, 'model'):
-        model.model = torch.compile(
-            model.model, backend="hpu_backend", options={"keep_input_mutations": True}
-        )
+    elif hasattr(model, "model"):
+        model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
     else:
         logger.warning(
             "In low performance case, please explicitly specify a module you want to wrap with `torch.compile`"