Skip to content

Commit

Permalink
Simplify round up to a multiple of 4: replace loop by aligned number …
Browse files Browse the repository at this point in the history
…calculation

PiperOrigin-RevId: 606415687
  • Loading branch information
rybakov authored and pax authors committed Feb 13, 2024
1 parent 763cc99 commit 3e202b6
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions praxis/gshard_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,7 @@ def _cap_logits(logits):
if expert_capacity_dim < auto_expert_capacity:
expert_capacity_dim = auto_expert_capacity
# Round up to a multiple of 4 to avoid possible padding.
while expert_capacity_dim % 4:
expert_capacity_dim += 1
expert_capacity_dim = ((expert_capacity_dim + 3) // 4) * 4
logging.info(
'Setting expert_capacity_dim=%r (capacity_factor=%r '
'group_size_dim=%r experts_dim=%r)',
Expand Down Expand Up @@ -505,8 +504,7 @@ def _cap_logits(logits):
if expert_capacity_dim < auto_expert_capacity:
expert_capacity_dim = auto_expert_capacity
# Round up to a multiple of 4 to avoid possible padding.
while expert_capacity_dim % 4:
expert_capacity_dim += 1
expert_capacity_dim = ((expert_capacity_dim + 3) // 4) * 4
logging.info(
'Setting expert_capacity_dim=%r (capacity_factor=%r '
'group_size_dim=%r experts_dim=%r)',
Expand Down Expand Up @@ -900,8 +898,7 @@ def _cap_logits(logits):
if expert_capacity_dim < auto_expert_capacity:
expert_capacity_dim = auto_expert_capacity
# Round up to a multiple of 4 to avoid possible padding.
while expert_capacity_dim % 4:
expert_capacity_dim += 1
expert_capacity_dim = ((expert_capacity_dim + 3) // 4) * 4
logging.info(
'Setting expert_capacity_dim=%r (capacity_factor=%r '
'group_size_dim=%r experts_dim=%r)',
Expand Down

0 comments on commit 3e202b6

Please sign in to comment.