Skip to content

Commit 3a2b9e6

Browse files
committed
added style
1 parent 39e8279 commit 3a2b9e6

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

examples/text-generation/utils.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -179,20 +179,16 @@ def patch_scoped_linear_all_reduce(model):
179179

180180
def get_torch_compiled_model(model, logger):
181181
# for gpt_bigcode, mpt, bloom, gpt2 model_type
182-
if hasattr(model, 'transformer'):
182+
if hasattr(model, "transformer"):
183183
model.transformer = torch.compile(
184184
model.transformer, backend="hpu_backend", options={"keep_input_mutations": True}
185185
)
186186
# for gpt_neox
187-
elif hasattr(model, 'gpt_neox'):
188-
model.gpt_neox = torch.compile(
189-
model.gpt_neox, backend="hpu_backend", options={"keep_input_mutations": True}
190-
)
187+
elif hasattr(model, "gpt_neox"):
188+
model.gpt_neox = torch.compile(model.gpt_neox, backend="hpu_backend", options={"keep_input_mutations": True})
191189
# for llama, mistral, mixtral, qwen2
192-
elif hasattr(model, 'model'):
193-
model.model = torch.compile(
194-
model.model, backend="hpu_backend", options={"keep_input_mutations": True}
195-
)
190+
elif hasattr(model, "model"):
191+
model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
196192
else:
197193
logger.warning(
198194
"In low performance case, please explicitly specify a module you want to wrap with `torch.compile`"

0 commit comments

Comments
 (0)