diff --git a/merlin/models/tf/__init__.py b/merlin/models/tf/__init__.py
index d18411881a..cf8ecbd1df 100644
--- a/merlin/models/tf/__init__.py
+++ b/merlin/models/tf/__init__.py
@@ -158,6 +158,7 @@ from merlin.models.tf.transforms.sequence import (
     ReplaceMaskedEmbeddings,
     SequenceMaskLast,
+    SequenceMaskLastInference,
     SequenceMaskRandom,
     SequencePredictLast,
     SequencePredictNext,
diff --git a/merlin/models/tf/transformers/block.py b/merlin/models/tf/transformers/block.py
index e2259ede9c..ed55f9aac0 100644
--- a/merlin/models/tf/transformers/block.py
+++ b/merlin/models/tf/transformers/block.py
@@ -91,7 +91,7 @@ def __init__(
             self.transformer = get_tf_main_layer(transformer)
         else:
             self.transformer = transformer
-
+        self.transformer.supports_masking = True
         if "transformer" in inspect.signature(transformer_pre.__init__).parameters:
             transformer_pre = transformer_pre(transformer=self.transformer)
         self.transformer_pre = transformer_pre
diff --git a/merlin/models/tf/transformers/transforms.py b/merlin/models/tf/transformers/transforms.py
index d0739dec9f..ace3fce40b 100644
--- a/merlin/models/tf/transformers/transforms.py
+++ b/merlin/models/tf/transformers/transforms.py
@@ -37,10 +37,66 @@ class LastHiddenState(Layer):
         The output class returned by the HuggingFace transformer layer
     """
 
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.supports_masking = True
+
     def call(self, inputs: TFBaseModelOutputWithPoolingAndCrossAttentions):
         return inputs.last_hidden_state
 
 
+@Block.registry.register("inference_hidden_state")
+@tf.keras.utils.register_keras_serializable(package="merlin.models")
+class TransformerInferenceHiddenState(Layer):
+    """A post-processing layer to select the hidden state
+    of the next-item position, during inference.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.supports_masking = True
+
+    def call(
+        self,
+        inputs: tf.Tensor,
+        training: bool = False,
+        testing: bool = False,
+    ):
+        """Select the hidden state of the target (last) position, during inference.
+        During training or testing, the inputs are returned
+        without any processing.
+
+        Parameters
+        ----------
+        inputs: tf.Tensor
+            The 3-D output tensor returned by the transformer block
+        training : bool, optional
+            Flag that indicates whether in training mode, by default False
+        testing : bool, optional
+            Flag that indicates whether in evaluation mode, by default False
+
+        Returns
+        -------
+        tf.Tensor
+            If inference, returns a 2-D tensor with the hidden states of
+            the target position
+        """
+        batch_size = tf.shape(inputs)[0]
+        if not training and not testing:
+            if getattr(inputs, "_keras_mask", None) is not None:
+                inputs = tf.reshape(
+                    tf.boolean_mask(inputs, inputs._keras_mask), (-1, inputs.shape[-1])
+                )
+            tf.debugging.assert_equal(
+                tf.shape(inputs)[0],
+                batch_size,
+                f"The resulting tensor has {tf.shape(inputs)[0]} rows, which does not match"
+                f" the inputs batch-size {batch_size}. During inference only one position "
+                "candidate (the last one) should be masked per example",
+            )
+        return inputs
+
+
 @Block.registry.register("pooler_output")
 @tf.keras.utils.register_keras_serializable(package="merlin.models")
 class PoolerOutput(Layer):
@@ -113,10 +169,21 @@ def call(self, inputs: TFBaseModelOutputWithPoolingAndCrossAttentions):
 class PrepareTransformerInputs(tf.keras.layers.Layer):
     """Prepare the dictionary of inputs expected by the transformer layer"""
 
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.supports_masking = True
+
     def call(self, inputs: tf.Tensor) -> Dict[str, tf.Tensor]:
+        mask = None
+        if getattr(inputs, "_keras_mask", None) is not None and isinstance(
+            inputs._keras_mask, tf.RaggedTensor
+        ):
+            mask = inputs._keras_mask.to_tensor()
         if isinstance(inputs, tf.RaggedTensor):
             # convert to a dense tensor as HF transformers do not support ragged tensors
             inputs = inputs.to_tensor()
+        if mask is not None:
+            inputs._keras_mask = mask
         return {"inputs_embeds": inputs}


diff --git a/merlin/models/tf/transforms/sequence.py b/merlin/models/tf/transforms/sequence.py
index 8350a948d7..aad51da98d 100644
--- a/merlin/models/tf/transforms/sequence.py
+++ b/merlin/models/tf/transforms/sequence.py
@@ -621,22 +621,67 @@ def from_config(cls, config):
         return cls(schema, target, **config)
 
 
+@tf.keras.utils.register_keras_serializable(package="merlin.models")
+class SequenceMaskLastInference(Block):
+    def call(self, inputs, training=False, testing=False):
+        self.inference_mode = not training and not testing
+        if self.inference_mode:
+            # Extending sequences by one position by copying the last embedding
+            repeat = inputs[:, -1:, :]
+            # repeat = tf.expand_dims(repeat, 1)
+            inputs = tf.concat([inputs, repeat], axis=1)
+        return inputs
+
+    def compute_mask(self, inputs, mask=None):
+        """Selects (masks) the next position after the
+        last valid (non-padded) position of the sequential targets
+        to be predicted.
+        This method is called by Keras after call()
+        and returns the mask that is going to be assigned
+        to the input tensors, being accessible
+        by tensor._keras_mask
+        """
+
+        targets_mask = None
+        if self.inference_mode:
+            if isinstance(inputs, tf.RaggedTensor):
+                row_lengths = inputs.row_lengths(1) + 1
+                max_seq_length = tf.cast(tf.reduce_max(row_lengths), tf.int32)
+
+                padding_mask = tf.sequence_mask(row_lengths)
+                targets_mask = tf.ragged.boolean_mask(
+                    tf.cast(tf.one_hot(row_lengths - 1, max_seq_length), tf.bool), padding_mask
+                )
+
+        return targets_mask
+
+
 @tf.keras.utils.register_keras_serializable(package="merlin.models")
 class ReplaceMaskedEmbeddings(Block):
     """Takes a 3D input tensor (batch size x seq. length x embedding dim) and replaces
-    by a dummy trainable single embedding at the positions to be masked.
-    This block looks for the Keras mask (`._keras_mask`) in the following order:
-      1. Checks if the input tensor has a mask
-      2. Checks if there is a single target and if it has a mask
-      3. If there are multiple targets (dict) returns the mask of the target
-      that matches the first 2 dims of the input
-    This is useful to be used when PredictMasked() transformation is used in
-    the Loader, which randomly selects some targets to be predicted and uses
-    Keras Masking to cascade the `_keras_mask`. By replacing input embeddings
-    at masked positions we avoid target leakage when training models with
-    Masked Language Modeling (BERT-like)
+    by a dummy trainable single embedding at the positions to be masked.
+    This block looks for the Keras mask (`._keras_mask`) in the following order:
+      1. Checks if the input tensor has a mask
+      2. Checks if there is a single target and if it has a mask
+      3. If there are multiple targets (dict) returns the mask of the target
+      that matches the first 2 dims of the input
+    This is useful to be used when PredictMasked() transformation is used in
+    the Loader, which randomly selects some targets to be predicted and uses
+    Keras Masking to cascade the `_keras_mask`. By replacing input embeddings
+    at masked positions we avoid target leakage when training models with
+    Masked Language Modeling (BERT-like)
+
+    **Note:** To support inference, the input sequence and its corresponding mask should be
+    extended by one position at the end to account for the next-item (`target`) position.
+    To do this, you should set `SequenceMaskLastInference` as a pre-layer of
+    `ReplaceMaskedEmbeddings()` using the sequential-block:
+    ```mm.SequentialBlock([mm.SequenceMaskLastInference(), mm.ReplaceMaskedEmbeddings()])```
     """
 
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.supports_masking = True
+
     def build(self, input_shape):
         self.hidden_size = input_shape[-1]
         if self.hidden_size is None:
diff --git a/tests/unit/tf/transformers/test_block.py b/tests/unit/tf/transformers/test_block.py
index 8c86abab55..67ee4dfff9 100644
--- a/tests/unit/tf/transformers/test_block.py
+++ b/tests/unit/tf/transformers/test_block.py
@@ -284,7 +284,13 @@ def test_transformer_with_masked_language_modeling(sequence_testing_data: Datase
                 seq_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None
             ),
         ),
-        BertBlock(d_model=48, n_head=8, n_layer=2, pre=mm.ReplaceMaskedEmbeddings()),
+        BertBlock(
+            d_model=48,
+            n_head=8,
+            n_layer=2,
+            pre=mm.SequentialBlock([mm.SequenceMaskLastInference(), mm.ReplaceMaskedEmbeddings()]),
+            post="inference_hidden_state",
+        ),
         mm.CategoricalOutput(
             seq_schema.select_by_name(target),
             default_loss="categorical_crossentropy",
@@ -308,10 +314,9 @@ def test_transformer_with_masked_language_modeling(sequence_testing_data: Datase
     metrics = model.evaluate(loader, batch_size=8, steps=1, return_dict=True, pre=seq_mask_last)
     assert len(metrics) > 0
 
+    # Get predictions for next-item position
     predictions = model.predict(loader, batch_size=8, steps=1)
-    # TODO: Decide what should be the output of predictions for MLM (currently it predicts for all
-    # positions of the sequence, but typically you want a single next-item prediction)
-    assert predictions.shape == (8, 4, 51997)
+    assert predictions.shape == (8, 51997)
 
 
 @pytest.mark.parametrize("run_eagerly", [True, False])
diff --git a/tests/unit/tf/transforms/test_sequence.py b/tests/unit/tf/transforms/test_sequence.py
index 5b2abe6fc9..05f8011751 100644
--- a/tests/unit/tf/transforms/test_sequence.py
+++ b/tests/unit/tf/transforms/test_sequence.py
@@ -252,7 +252,7 @@ def test_replace_masked_input_embeddings_no_target():
     targets = None
 
     masked_embeddings = mm.ReplaceMaskedEmbeddings()
-    output = masked_embeddings(item_id_emb_seq, targets=targets)
+    output = masked_embeddings(item_id_emb_seq, targets=targets, training=True)
     # Checks that no input embedding was replaced, as there was no masking defined
     tf.Assert(tf.logical_not(tf.reduce_all(output == item_id_emb_seq)), [])
 
@@ -262,7 +262,7 @@ def test_not_replace_unmasked_sequence_embeddings():
     targets = tf.random.uniform((8, 10), dtype=tf.float32)
 
     masked_embeddings = mm.ReplaceMaskedEmbeddings()
-    output = masked_embeddings(item_id_emb_seq, targets=targets)
+    output = masked_embeddings(item_id_emb_seq, targets=targets, training=True)
     # Checks that no input embedding was replaced, as there was no masking defined
     tf.Assert(tf.reduce_all(output == item_id_emb_seq), [])
 
@@ -275,7 +275,7 @@ def test_replace_masked_input_2d_embeddings_incompatible_2d_mask():
 
     masked_embeddings = mm.ReplaceMaskedEmbeddings()
     with pytest.raises(Exception) as exc_info:
-        _ = masked_embeddings(item_id_emb_seq)
+        _ = masked_embeddings(item_id_emb_seq, training=True)
     assert "The inputs and mask need to be compatible" in str(exc_info.value)
 
@@ -287,7 +287,7 @@ def test_replace_masked_input_2d_embeddings_incompatible_ragged_2d_mask():
 
     masked_embeddings = mm.ReplaceMaskedEmbeddings()
     with pytest.raises(Exception) as exc_info:
-        _ = masked_embeddings(item_id_emb_seq)
+        _ = masked_embeddings(item_id_emb_seq, training=True)
     assert "The inputs and mask need to be compatible" in str(exc_info.value)
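
Not part of the patch itself, but to make the intended wiring concrete, here is a minimal sketch adapted from the updated `test_transformer_with_masked_language_modeling` test above. It shows how `SequenceMaskLastInference`, `ReplaceMaskedEmbeddings`, and the registered `"inference_hidden_state"` post-block are meant to be combined. The dataset (`train`), schema selection, and hyperparameters are illustrative assumptions, not values taken from this diff.

```python
import merlin.models.tf as mm
from merlin.schema import Tags

# Assumption: `train` is a merlin.io.Dataset with sequential (list) features.
seq_schema = train.schema.select_by_tag(Tags.SEQUENCE)
target = train.schema.select_by_tag(Tags.ITEM_ID).column_names[0]

model = mm.Model(
    mm.InputBlockV2(
        seq_schema,
        categorical=mm.Embeddings(
            seq_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None
        ),
    ),
    # New in this change: SequenceMaskLastInference extends each sequence by one
    # position at inference time (and masks only that position), and the
    # "inference_hidden_state" post-block keeps only the hidden state of that
    # next-item position, so predict() returns a single prediction per example.
    mm.BertBlock(
        d_model=48,
        n_head=8,
        n_layer=2,
        pre=mm.SequentialBlock([mm.SequenceMaskLastInference(), mm.ReplaceMaskedEmbeddings()]),
        post="inference_hidden_state",
    ),
    mm.CategoricalOutput(
        seq_schema.select_by_name(target),
        default_loss="categorical_crossentropy",
    ),
)

model.compile(optimizer="adam", run_eagerly=False)

# Train with random masking (masked language modeling), mirroring the existing test.
model.fit(
    train,
    batch_size=8,
    epochs=1,
    pre=mm.SequenceMaskRandom(schema=seq_schema, target=target, masking_prob=0.3),
)

# With this change, inference yields one next-item prediction per example:
# shape (batch_size, item_id_cardinality) instead of (batch_size, seq_len, cardinality).
predictions = model.predict(train, batch_size=8)
```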