New design of the transformer API to support causal and masked pre-training approaches #1008
Conversation
Sara, does our new Transformer API already support dense tensors for inputs and targets instead of RaggedTensor?
The dataloader provides dense tensors for sequential features in some cases (as summarized in this ADR), as illustrated in the sketch after this list:
- In the current dataloader API, if `value_count.max is not None` and `is_ragged == False`
- In the future dataloader API, if `is_ragged == False`
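For illustration, here is a minimal sketch (not the actual dataloader code) of the shapes a model would receive in the two cases, assuming a toy batch of item-id sequences and a padding length that stands in for `value_count.max`:

```python
import tensorflow as tf

# is_ragged == True: variable-length sequences arrive as a RaggedTensor.
ragged_ids = tf.ragged.constant([[1, 2, 3], [4, 5]], dtype=tf.int64)
print(ragged_ids.shape)  # (2, None)

# is_ragged == False with a known value_count.max: sequences arrive dense,
# padded (here with 0) to the fixed maximum length.
max_len = 4  # stands in for value_count.max
dense_ids = ragged_ids.to_tensor(default_value=0, shape=[None, max_len])
print(dense_ids.shape)  # (2, 4)
```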
```python
# losses does not support RaggedVariantTensor on GPU:
prediction = prediction.flat_values
if isinstance(target, tf.RaggedTensor):
    target = target.flat_values
```
As you are flattening the values here to 1D, is there a way to reshape the losses output back into a RaggedTensor? Otherwise the 1D loss will not match the sample weights, which can be either 1D or 2D (ragged).
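One possible way to do that (a sketch, not the PR's implementation) is to compute the element-wise loss on the flattened values and then re-attach the original row partitions with `RaggedTensor.with_flat_values`, so the loss is ragged again and lines up with 2D ragged sample weights:

```python
import tensorflow as tf

target = tf.ragged.constant([[1.0, 0.0, 1.0], [0.0, 1.0]])
prediction = tf.ragged.constant([[0.9, 0.2, 0.8], [0.1, 0.7]])

# 1D loss computed on the flattened values, as in the diff above.
flat_loss = tf.keras.losses.binary_crossentropy(
    tf.expand_dims(target.flat_values, -1),
    tf.expand_dims(prediction.flat_values, -1),
)

# Re-attach the row partitions of the target so the loss is 2D (ragged)
# again and can be weighted element-wise by ragged sample weights.
ragged_loss = target.with_flat_values(flat_loss)
print(ragged_loss.shape)  # (2, None)
```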
Force-pushed from a02d8d1 to 20a40d7
Thank you for checking out the PR! This PR only addresses how to support the different masking approaches in the TransformerBlock, but we still need to work on extending the SequenceTransforms to support dense tensors as inputs (as mentioned in this ADR).
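For context, here is a generic sketch of the two pre-training approaches the PR title refers to; the tensor names and the `mask_id` / `-1` conventions are illustrative assumptions, not the Merlin Models API:

```python
import tensorflow as tf

seq = tf.constant([[10, 11, 12, 13]])  # toy item-id sequence, shape (1, 4)

# Causal (CLM-style): predict each item from the ones before it, i.e. the
# targets are the inputs shifted by one position.
clm_inputs, clm_targets = seq[:, :-1], seq[:, 1:]

# Masked (MLM-style): randomly replace some positions with a [MASK] id and
# predict only the masked positions.
mask_id = 0  # hypothetical reserved id
mask = tf.random.uniform(tf.shape(seq)) < 0.2  # mask ~20% of positions
mlm_inputs = tf.where(mask, mask_id, seq)
mlm_targets = tf.where(mask, seq, -1)  # -1 marks positions excluded from the loss
```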
Force-pushed from 5b5ef80 to a92bdc2
Closing, as this was a placeholder for the tutorial image.
This is a placeholder to support the Transformer API for the GTC 2023 tutorial. This branch is rebased on release-23.02.
For the latest work intended to be merged into the main branch, please refer to #1022.