-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
New design of the transformer API (#1022)
* implement new design of the Transformer API on top of the release-23.02 branch * add support of ragged tensor to weight tying * update example notebook with the new API * include PR comments * fix masking of sequence-predict-next transform * adjust sample_weights to targets shape * add masking support to SequencePredictRandom transform * rebase with main branch to include data loader changes * fix linting * Fix the adjust-predictions logic to support targets as 2-D scalars * Fix transformer example notebook * update import of transformer blocks in transforms/sequence and move them inside configure_for_train() function.
- Loading branch information
Showing
9 changed files
with
876 additions
and
179 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.