Fix the inference of transformer-based models trained with masked language modeling #909

Merged: 9 commits, Dec 13, 2022
1 change: 1 addition & 0 deletions merlin/models/tf/__init__.py
@@ -159,6 +159,7 @@
from merlin.models.tf.transforms.sequence import (
    ReplaceMaskedEmbeddings,
    SequenceMaskLast,
    SequenceMaskLastInference,
    SequenceMaskRandom,
    SequencePredictLast,
    SequencePredictNext,
2 changes: 1 addition & 1 deletion merlin/models/tf/transformers/block.py
@@ -91,7 +91,7 @@ def __init__(
            self.transformer = get_tf_main_layer(transformer)
        else:
            self.transformer = transformer

        self.transformer.supports_masking = True
Contributor Author comment: Set `supports_masking=True` so that the transformer layer can forward the mask tensor created by SequenceMaskLastInference.
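For background, a minimal sketch (the `PassThrough` layer is illustrative, not part of this PR) of what `supports_masking` enables: a layer that opts in simply forwards the mask attached by an upstream layer, so downstream blocks can read it from `_keras_mask`.

```python
import tensorflow as tf


class PassThrough(tf.keras.layers.Layer):
    """Illustrative layer: with supports_masking=True, Keras propagates
    the incoming mask to downstream layers instead of dropping it."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True

    def call(self, inputs):
        return inputs


masking = tf.keras.layers.Masking(mask_value=0.0)
x = tf.constant([[[1.0], [0.0], [2.0]]])  # second step equals mask_value
y = PassThrough()(masking(x))
print(y._keras_mask)  # tf.Tensor([[ True False  True]], shape=(1, 3), dtype=bool)
```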

if "transformer" in inspect.signature(transformer_pre.__init__).parameters:
transformer_pre = transformer_pre(transformer=self.transformer)
self.transformer_pre = transformer_pre
59 changes: 59 additions & 0 deletions merlin/models/tf/transformers/transforms.py
@@ -37,10 +37,58 @@ class LastHiddenState(Layer):
    The output class returned by the HuggingFace transformer layer
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True

    def call(self, inputs: TFBaseModelOutputWithPoolingAndCrossAttentions):
        return inputs.last_hidden_state


@Block.registry.register("inference_hidden_state")
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class InferenceHiddenState(Layer):
"""A post-processing layer to select the hidden state
of the next-item position, during inference.
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.supports_masking = True

def call(
self,
inputs: tf.Tensor,
training: bool = False,
testing: bool = False,
):
"""Select the hidden state of the target position, during inference.
During training or testing, the inputs are returned
without any processing.

Parameters:
----------
inputs: tf.Tensor
The 3-D output tensor returned by the transformer block
training : bool, optional
Flag that indicates whether in training mode, by default True
testing : bool, optional
Flag that indicates whether in evaluation mode, by default True

Returns
-------
tf.Tensor
If inference, returns a 2-D tensor with the hidden states of
the target position
"""
        if not training and not testing:
            if getattr(inputs, "_keras_mask", None) is not None:
                inputs = tf.reshape(
                    tf.boolean_mask(inputs, inputs._keras_mask), (-1, inputs.shape[-1])
                )
        return inputs

Member comment: I think we can add an assert here to check that the inputs resulting from boolean_mask() keep the same first dimension as before. We are trusting that during inference only one position (the last one) is masked per example; otherwise the batch size will differ after this conversion. Adding such an assert should help users understand that this is an assumption of this block.
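A hedged sketch of the suggested assert (the helper name and the explicit `mask` argument are assumptions for illustration, not this PR's code):

```python
import tensorflow as tf


def select_target_hidden_state(inputs: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
    """Sketch of the selection above plus the suggested check; assumes
    exactly one masked (target) position per example."""
    batch_size = tf.shape(inputs)[0]
    selected = tf.reshape(tf.boolean_mask(inputs, mask), (-1, inputs.shape[-1]))
    # If more (or fewer) than exactly one position were masked per example,
    # the first dimension would no longer match the batch size.
    tf.debugging.assert_equal(
        tf.shape(selected)[0],
        batch_size,
        message="Expected exactly one masked (target) position per example",
    )
    return selected


hidden = tf.random.uniform((8, 4, 48))
mask = tf.one_hot(tf.fill((8,), 3), depth=4, on_value=True, off_value=False)
out = select_target_hidden_state(hidden, mask)  # shape: (8, 48)
```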


@Block.registry.register("pooler_output")
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class PoolerOutput(Layer):
@@ -113,10 +161,21 @@ def call(self, inputs: TFBaseModelOutputWithPoolingAndCrossAttentions):
class PrepareTransformerInputs(tf.keras.layers.Layer):
    """Prepare the dictionary of inputs expected by the transformer layer"""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True

    def call(self, inputs: tf.Tensor) -> Dict[str, tf.Tensor]:
        mask = None
        if getattr(inputs, "_keras_mask", None) is not None and isinstance(
            inputs._keras_mask, tf.RaggedTensor
        ):
            mask = inputs._keras_mask.to_tensor()
        if isinstance(inputs, tf.RaggedTensor):
            # convert to a dense tensor as HF transformers do not support ragged tensors
            inputs = inputs.to_tensor()
        if mask is not None:
            inputs._keras_mask = mask
        return {"inputs_embeds": inputs}


67 changes: 56 additions & 11 deletions merlin/models/tf/transforms/sequence.py
@@ -621,22 +621,67 @@ def from_config(cls, config):
        return cls(schema, target, **config)


@tf.keras.utils.register_keras_serializable(package="merlin.models")
class SequenceMaskLastInference(Block):
    def call(self, inputs, training=False, testing=False):
        self.inference_mode = not training and not testing
        if self.inference_mode:
            # Extend each sequence by one position, copying the last embedding
            repeat = inputs[:, -1:, :]
            inputs = tf.concat([inputs, repeat], axis=1)
        return inputs

    def compute_mask(self, inputs, mask=None):
        """Selects (masks) the next position after the
        last valid (non-padded) position of the sequential targets
        to be predicted.
        This method is called by Keras after call()
        and returns the mask that will be assigned
        to the input tensors, accessible via
        tensor._keras_mask
        """

        targets_mask = None
        if self.inference_mode:
            if isinstance(inputs, tf.RaggedTensor):
                row_lengths = inputs.row_lengths(1) + 1
                max_seq_length = tf.cast(tf.reduce_max(row_lengths), tf.int32)

                padding_mask = tf.sequence_mask(row_lengths)
                targets_mask = tf.ragged.boolean_mask(
                    tf.cast(tf.one_hot(row_lengths - 1, max_seq_length), tf.bool), padding_mask
                )

        return targets_mask
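As a worked example of the mask construction above (assumed lengths): for original sequence lengths [2, 4], call() appends one position, so the extended lengths are [3, 5], and one_hot(row_lengths - 1, ...) flags only the appended next-item position:

```python
import tensorflow as tf

# Suppose the original (pre-extension) sequences had lengths [2, 4].
# SequenceMaskLastInference appends one position, so lengths become [3, 5].
row_lengths = tf.constant([2, 4], dtype=tf.int64) + 1
max_seq_length = tf.cast(tf.reduce_max(row_lengths), tf.int32)

padding_mask = tf.sequence_mask(row_lengths)
targets_mask = tf.ragged.boolean_mask(
    tf.cast(tf.one_hot(row_lengths - 1, max_seq_length), tf.bool), padding_mask
)
print(targets_mask)
# <tf.RaggedTensor [[False, False, True], [False, False, False, False, True]]>
```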


@tf.keras.utils.register_keras_serializable(package="merlin.models")
class ReplaceMaskedEmbeddings(Block):
    """Takes a 3-D input tensor (batch size x seq. length x embedding dim) and replaces
    the embeddings at the positions to be masked by a single dummy trainable embedding.
    This block looks for the Keras mask (`._keras_mask`) in the following order:
    1. Checks if the input tensor has a mask
    2. Checks if there is a single target and if it has a mask
    3. If there are multiple targets (dict), returns the mask of the target
       that matches the first 2 dims of the input
    This is useful when the PredictMasked() transformation is used in
    the Loader, which randomly selects some targets to be predicted and uses
    Keras masking to cascade the `_keras_mask`. By replacing the input embeddings
    at masked positions we avoid target leakage when training models with
    Masked Language Modeling (BERT-like).

    **Note:** To support inference, the input sequence and its corresponding mask should be
    extended by one position at the end to account for the next-item (`target`) position.
    To do this, set `SequenceMaskLastInference` as a pre-layer of
    `ReplaceMaskedEmbeddings()` using a sequential block:
    ```mm.SequentialBlock([mm.SequenceMaskLastInference(), mm.ReplaceMaskedEmbeddings()])```
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True

    def build(self, input_shape):
        self.hidden_size = input_shape[-1]
        if self.hidden_size is None:
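To make the replacement idea concrete, a minimal sketch under assumed shapes (the `masked_embedding` variable and the `tf.where` call are illustrative, not the block's actual implementation):

```python
import tensorflow as tf

batch, seq_len, dim = 8, 10, 48
inputs = tf.random.uniform((batch, seq_len, dim))
mask = tf.random.uniform((batch, seq_len)) > 0.8  # positions to replace

# A single trainable embedding shared by all masked positions
masked_embedding = tf.Variable(tf.random.normal((dim,)))

# Broadcast the 2-D mask over the embedding dim and swap in the embedding
outputs = tf.where(mask[..., tf.newaxis], masked_embedding, inputs)
assert outputs.shape == inputs.shape
```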
13 changes: 9 additions & 4 deletions tests/unit/tf/transformers/test_block.py
@@ -281,7 +281,13 @@ def test_transformer_with_masked_language_modeling(sequence_testing_data: Datase
            seq_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None
        ),
    ),
    BertBlock(
        d_model=48,
        n_head=8,
        n_layer=2,
        pre=mm.SequentialBlock([mm.SequenceMaskLastInference(), mm.ReplaceMaskedEmbeddings()]),
        post="inference_hidden_state",
    ),
    mm.CategoricalOutput(
        seq_schema.select_by_name(target),
        default_loss="categorical_crossentropy",
@@ -304,10 +310,9 @@ def test_transformer_with_masked_language_modeling(sequence_testing_data: Datase
    metrics = model.evaluate(loader, batch_size=8, steps=1, return_dict=True, pre=seq_mask_last)
    assert len(metrics) > 0

    # Get predictions for the next-item position only: one score
    # vector per example instead of one per sequence position
    predictions = model.predict(loader, batch_size=8, steps=1)
    assert predictions.shape == (8, 51997)


@pytest.mark.parametrize("run_eagerly", [True, False])
8 changes: 4 additions & 4 deletions tests/unit/tf/transforms/test_sequence.py
@@ -274,7 +274,7 @@ def test_replace_masked_input_embeddings_no_target():
    targets = None

    masked_embeddings = mm.ReplaceMaskedEmbeddings()
    output = masked_embeddings(item_id_emb_seq, targets=targets, training=True)
    # Checks that no input embedding was replaced, as there was no masking defined
    tf.Assert(tf.logical_not(tf.reduce_all(output == item_id_emb_seq)), [])

@@ -284,7 +284,7 @@ def test_not_replace_unmasked_sequence_embeddings():
    targets = tf.random.uniform((8, 10), dtype=tf.float32)

    masked_embeddings = mm.ReplaceMaskedEmbeddings()
    output = masked_embeddings(item_id_emb_seq, targets=targets, training=True)
    # Checks that no input embedding was replaced, as there was no masking defined
    tf.Assert(tf.reduce_all(output == item_id_emb_seq), [])

@@ -297,7 +297,7 @@ def test_replace_masked_input_2d_embeddings_incompatible_2d_mask():
    masked_embeddings = mm.ReplaceMaskedEmbeddings()

    with pytest.raises(Exception) as exc_info:
        _ = masked_embeddings(item_id_emb_seq, training=True)
    assert "The inputs and mask need to be compatible" in str(exc_info.value)


@@ -309,7 +309,7 @@ def test_replace_masked_input_2d_embeddings_incompatible_ragged_2d_mask():
    masked_embeddings = mm.ReplaceMaskedEmbeddings()

    with pytest.raises(Exception) as exc_info:
        _ = masked_embeddings(item_id_emb_seq, training=True)
    assert "The inputs and mask need to be compatible" in str(exc_info.value)

