Skip to content

Commit

Permalink
Merge pull request #77 from hx89:checkpoint_fix
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 647034147
  • Loading branch information
pax authors committed Jun 26, 2024
2 parents 7293a44 + 4878086 commit f0149a3
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
3 changes: 3 additions & 0 deletions praxis/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1698,6 +1698,9 @@ def __call__(
query_proj = self.query(query_vec)
key_proj = self.key(key_vec)
value_proj = self.value(value_vec)
query_proj = checkpoint_name(query_proj, 'query_proj')
key_proj = checkpoint_name(key_proj, 'key_proj')
value_proj = checkpoint_name(value_proj, 'value_proj')

if not self.consolidate_rope_key_state:
self._fprop_update_decode_state('key_state', key_proj)
Expand Down
1 change: 1 addition & 0 deletions praxis/layers/grok.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,5 +350,6 @@ def GrokUniTransformerLmHParams(
num_pipeline_stages=num_pipeline_stages,
num_pipeline_microbatches=num_pipeline_microbatches,
stream_io=True,
checkpoint_policy=checkpoint_policy,
)
return p

0 comments on commit f0149a3

Please sign in to comment.