Skip to content

Commit

Permalink
In jnp.reshape, use shape rather than deprecated newshape. The `n…
Browse files Browse the repository at this point in the history
…ewshape` parameter was deprecated in JAX v0.4.28, and will soon be removed.

PiperOrigin-RevId: 695445768
  • Loading branch information
Jake VanderPlas authored and pax authors committed Nov 11, 2024
1 parent 9693c08 commit 899b56e
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions praxis/decoder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,14 @@ def two_stage_topk(
beam_size = hyp_scores.shape[1]
tokens_per_beam = beam_size if tokens_per_beam is None else tokens_per_beam
logits_reshape = jnp.reshape(
logits, newshape=(batch_size * beam_size, vocab_size)
logits, shape=(batch_size * beam_size, vocab_size)
)
topk_value, topk_indices = jax.lax.top_k(logits_reshape, tokens_per_beam)
topk_value = jnp.reshape(
topk_value, newshape=(batch_size, beam_size, tokens_per_beam)
topk_value, shape=(batch_size, beam_size, tokens_per_beam)
)
topk_indices = jnp.reshape(
topk_indices, newshape=(batch_size, beam_size, tokens_per_beam)
topk_indices, shape=(batch_size, beam_size, tokens_per_beam)
)
topk_value += jnp.expand_dims(hyp_scores, -1)
for terminal_id in terminal_ids:
Expand All @@ -176,10 +176,10 @@ def two_stage_topk(
)

topk_value = jnp.reshape(
topk_value, newshape=(batch_size, beam_size * tokens_per_beam)
topk_value, shape=(batch_size, beam_size * tokens_per_beam)
)
topk_indices = jnp.reshape(
topk_indices, newshape=(batch_size, beam_size * tokens_per_beam)
topk_indices, shape=(batch_size, beam_size * tokens_per_beam)
)

final_topk_value, final_topk_indices = jax.lax.top_k(topk_value, beam_size)
Expand Down

0 comments on commit 899b56e

Please sign in to comment.