Fix BroadcastToSequence to enable context features in sequential models #991
Conversation
@@ -101,6 +101,13 @@ def call(self, inputs: TabularData, **kwargs) -> TabularData:
        elif isinstance(val, tf.RaggedTensor):
            ragged = val
        else:
            # Expanding / setting last dim of non-list features to be 1D
oliverholworthy: Is this relevant to ProcessList? Intuitively, ProcessList sounds like it might only be transforming list features.
PR author: That is a good point @oliverholworthy. ProcessList is currently used in core places to ensure the features are in good shape for models. When the change making the dataloader output scalars as 1D happens, we will also need this fix that makes scalars 2D (batch size, 1) for models. What if we rename ProcessList to PrepareFeatures and have it as a generic block that works as a translation layer between the dataloader and models (used in the same places ProcessList is currently used)?
oliverholworthy: The rename sounds like a reasonable thing to do and better matches its purpose. It can be in another PR if preferred.
PR author: @oliverholworthy I have created a separate PR just for renaming ProcessList to PrepareFeatures: PR #992.
Merge commit: Fix BroadcastToSequence to enable context features in sequential models (#991)
* Fixed error that was causing the broadcasted context feature to have a fixed-size first dim in graph mode and not being compatible with the ragged sequential features
* Enforcing non-list (scalar) features to be 2D (batch size, 1) if 1D or with the last dim undefined (which happens in graph mode)
* Making Continuous supports_masking=True (to cascade the mask)
* Changing BroadcastToSequence to fix some issues and simplify the masking
* Fixed tests
* Fixed test
Fixes #989
Goals ⚽
This PR fixes BroadcastToSequence, which was not working properly in some cases.
Implementation Details 🚧
BroadcastToSequence was not working in graph mode because the contextual features were being expanded to match the shape of the sequential features with logic that made the first dim (batch size) of the resulting tensor fixed, while the other sequential tensors had their first dim as None (because the batch size may vary).
The solution I found to keep the first dim of the expanded context feature tensor matching the other sequential features was to use tf.ones_like() to create a (ragged) tensor matching the sequential feature shape and then multiply it by the context feature (equivalent to repeating it), as in the sketch below.
This PR also includes other fixes:
- The Continuous block now forwards masking (supports_masking=True).
- BroadcastToSequence was refactored to separate out the logic that checks whether the sequential and context features in the input match the schema; exceptions are now raised otherwise.
- In BroadcastToSequence.compute_mask() I don't call self._broadcast(), and the logic there was made simpler: the expanded contextual feature mask should match the sequential feature mask.
- SequenceTargetAsInput now returns a tuple instead of a Prediction, so that Keras can better align the inputs and targets output from call() and compute_mask() of child classes.
- ProcessList now reshapes scalar features (is_list=False, is_ragged=False) to be 2D (batch size, 1), as the last dim was None in graph mode and was causing issues when concatenating (see the sketch after this list).
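A minimal sketch of the scalar-reshaping idea (not the actual ProcessList implementation; the helper name and the feature dict below are illustrative):

```python
import tensorflow as tf

def ensure_scalar_features_2d(inputs: dict) -> dict:
    """Reshape scalar (non-list) features to (batch_size, 1).

    Illustrative helper, not the real ProcessList code: ragged (list) features are
    left untouched; dense scalar features that are 1D, or whose last dim is
    undefined in graph mode, get an explicit trailing dim of 1.
    """
    outputs = {}
    for name, val in inputs.items():
        if isinstance(val, tf.RaggedTensor):
            outputs[name] = val  # list feature, leave as-is
        elif val.shape.rank == 1:
            # (batch_size,) -> (batch_size, 1); expand_dims keeps the batch dim dynamic
            outputs[name] = tf.expand_dims(val, axis=-1)
        elif val.shape.rank == 2 and val.shape[-1] is None:
            # Last dim undefined in graph mode; assuming a scalar feature, make it 1
            outputs[name] = tf.reshape(val, (-1, 1))
        else:
            outputs[name] = val
    return outputs

# Example: a scalar feature with shape (batch_size,) becomes (batch_size, 1)
features = {"user_age": tf.constant([25.0, 31.0, 47.0])}
print(ensure_scalar_features_2d(features)["user_age"].shape)  # (3, 1)
```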
Testing Details 🔍
Have added new tests for BroadcastToSequence usage as a post of InputBlockV2, with both categorical and continuous context (non-sequential) features, and also to test it in a Transformer model trained with masked language modeling.
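For illustration, a hedged sketch of what such a usage could look like. The BroadcastToSequence(context_schema, sequence_schema) constructor arguments and the tag-based schema split are assumptions about the merlin.models API, not code from this PR:

```python
import merlin.models.tf as mm
from merlin.schema import Schema, Tags

def build_sequential_inputs(schema: Schema):
    # Assumed split: features tagged as SEQUENCE are the sequential (list) part,
    # everything else is treated as context (non-sequential) features.
    sequence_schema = schema.select_by_tag(Tags.SEQUENCE)
    context_schema = schema.excluding_by_tag(Tags.SEQUENCE)

    # BroadcastToSequence is applied as the `post` of InputBlockV2, so that the
    # context features get repeated along the sequence dimension of the list
    # features. The constructor signature here is assumed, not taken from the PR.
    return mm.InputBlockV2(
        schema,
        post=mm.BroadcastToSequence(context_schema, sequence_schema),
    )
```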