Skip to content

Commit a8f375c

Browse files
committed
Rebased; fixed two tests; ran `make style`.
1 parent 7288d12 commit a8f375c

File tree

2 files changed

+8 -4 lines changed

2 files changed

+8 -4 lines changed

tests/transformers/tests/models/falcon/test_modeling_falcon.py

+4-2
Original file line number | Diff line number | Diff line change
@@ -454,11 +454,12 @@ def test_model_rope_scaling(self):
454454
torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])
455455
# Sanity check linear RoPE scaling
456456
# New position "x" should match original position with index "x/scaling_factor"
457-
linear_scaling_rope = FalconLinearScalingRotaryEmbedding(
457+
linear_scaling_rope = FalconRotaryEmbedding(
458458
head_dim,
459459
max_position_embeddings=config.max_position_embeddings,
460460
base=config.rope_theta,
461461
scaling_factor=scaling_factor,
462+
rope_type="linear",
462463
).to(torch_device)
463464
linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
464465
linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
@@ -471,11 +472,12 @@ def test_model_rope_scaling(self):
471472
# Sanity check Dynamic NTK RoPE scaling
472473
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
473474
# with scaling_factor (or that `inv_freq` decreases)
474-
ntk_scaling_rope = FalconDynamicNTKScalingRotaryEmbedding(
475+
ntk_scaling_rope = FalconRotaryEmbedding(
475476
head_dim,
476477
max_position_embeddings=config.max_position_embeddings,
477478
base=config.rope_theta,
478479
scaling_factor=scaling_factor,
480+
rope_type="dynamic",
479481
).to(torch_device)
480482
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
481483
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)

tests/transformers/tests/models/gpt_neox/test_modeling_gpt_neox.py

+4-2
Original file line number | Diff line number | Diff line change
@@ -367,11 +367,12 @@ def test_model_rope_scaling(self):
367367
torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])
368368
# Sanity check linear RoPE scaling
369369
# New position "x" should match original position with index "x/scaling_factor"
370-
linear_scaling_rope = GPTNeoXLinearScalingRotaryEmbedding(
370+
linear_scaling_rope = GPTNeoXRotaryEmbedding(
371371
head_dim,
372372
max_position_embeddings=config.max_position_embeddings,
373373
base=config.rotary_emb_base,
374374
scaling_factor=scaling_factor,
375+
rope_type="linear",
375376
).to(torch_device)
376377
linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
377378
linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
@@ -384,11 +385,12 @@ def test_model_rope_scaling(self):
384385
# Sanity check Dynamic NTK RoPE scaling
385386
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
386387
# with scaling_factor (or that `inv_freq` decreases)
387-
ntk_scaling_rope = GPTNeoXDynamicNTKScalingRotaryEmbedding(
388+
ntk_scaling_rope = GPTNeoXRotaryEmbedding(
388389
head_dim,
389390
max_position_embeddings=config.max_position_embeddings,
390391
base=config.rotary_emb_base,
391392
scaling_factor=scaling_factor,
393+
rope_type="dynamic",
392394
).to(torch_device)
393395
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
394396
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)

0 commit comments

Comments (0)