@@ -367,11 +367,12 @@ def test_model_rope_scaling(self):
367
367
torch .testing .assert_close (original_sin_short , original_sin_long [:short_input_length , :])
368
368
# Sanity check linear RoPE scaling
369
369
# New position "x" should match original position with index "x/scaling_factor"
370
- linear_scaling_rope = GPTNeoXLinearScalingRotaryEmbedding (
370
+ linear_scaling_rope = GPTNeoXRotaryEmbedding (
371
371
head_dim ,
372
372
max_position_embeddings = config .max_position_embeddings ,
373
373
base = config .rotary_emb_base ,
374
374
scaling_factor = scaling_factor ,
375
+ rope_type = "linear" ,
375
376
).to (torch_device )
376
377
linear_cos_short , linear_sin_short = linear_scaling_rope (x , short_input_length )
377
378
linear_cos_long , linear_sin_long = linear_scaling_rope (x , long_input_length )
@@ -384,11 +385,12 @@ def test_model_rope_scaling(self):
384
385
# Sanity check Dynamic NTK RoPE scaling
385
386
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
386
387
# with scaling_factor (or that `inv_freq` decreases)
387
- ntk_scaling_rope = GPTNeoXDynamicNTKScalingRotaryEmbedding (
388
+ ntk_scaling_rope = GPTNeoXRotaryEmbedding (
388
389
head_dim ,
389
390
max_position_embeddings = config .max_position_embeddings ,
390
391
base = config .rotary_emb_base ,
391
392
scaling_factor = scaling_factor ,
393
+ rope_type = "dynamic" ,
392
394
).to (torch_device )
393
395
ntk_cos_short , ntk_sin_short = ntk_scaling_rope (x , short_input_length )
394
396
ntk_cos_long , ntk_sin_long = ntk_scaling_rope (x , long_input_length )
0 commit comments