From 62e05913c896a732d3e10edf6efab123cea4e9e8 Mon Sep 17 00:00:00 2001 From: Sara Rabhi Date: Tue, 21 Mar 2023 13:40:48 -0400 Subject: [PATCH] New design of the transformer API (#1022) * implement new design of the Transformer API on top of the release-23.02 branch * add support of ragged tensor to weight tying * update example notebook with the new API * include PR comments * fix masking of sequence-predict-next transform * adjust sample_weights to targets shape * add masking support to SequencePredictRandom transform * rebase with main branch to include data loader changes * fix linting * Fix the adjust-predictions logic to support targets as 2-D scalars * Fix transformer example notebook * update import of transformer blocks in transforms/sequence and move them inside configure_for_train() function. --- .../transformers-next-item-prediction.ipynb | 42 +- merlin/models/tf/models/base.py | 296 ++++++++++-- merlin/models/tf/outputs/classification.py | 7 +- merlin/models/tf/transformers/block.py | 56 ++- merlin/models/tf/transformers/transforms.py | 52 ++- merlin/models/tf/transforms/sequence.py | 435 +++++++++++++++--- merlin/models/tf/utils/tf_utils.py | 32 ++ tests/unit/tf/transformers/test_block.py | 122 +++-- tests/unit/tf/transforms/test_sequence.py | 13 +- 9 files changed, 876 insertions(+), 179 deletions(-) diff --git a/examples/usecases/transformers-next-item-prediction.ipynb b/examples/usecases/transformers-next-item-prediction.ipynb index 5cd1e8a8c6..e864685a81 100644 --- a/examples/usecases/transformers-next-item-prediction.ipynb +++ b/examples/usecases/transformers-next-item-prediction.ipynb @@ -671,6 +671,25 @@ "seq_schema" ] }, + { + "cell_type": "markdown", + "id": "0a87439c", + "metadata": {}, + "source": [ + "Align the schema of train and validation datasets with the model's schema" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b90424a", + "metadata": {}, + "outputs": [], + "source": [ + "train_set_processed.schema = seq_schema\n", + "validation_set_processed.schema = seq_schema" + ] + }, { "cell_type": "markdown", "id": "8d422833", @@ -724,7 +743,7 @@ "id": "0a460e4c", "metadata": {}, "source": [ - "For the transformer portion of our model, we will use the `XLNet` architecture. Additionally, we are passing `mm.ReplaceMaskedEmbeddings()` as our `pre` block. We will be training a masked language model and this parameter is responsible for the masking of our sequences." + "For the transformer portion of our model, we will use the `XLNet` architecture." ] }, { @@ -732,12 +751,12 @@ "id": "23bf02dc", "metadata": {}, "source": [ - "Later, when we run the `fit` method on our model, we will specify the `masking_probability` of `0.3`. Through the combination of these parameters, our model will train on sequences where any given timestep will be masked with a probability of 0.3 and it will be our model's training task to infer the target value for that step!\n", + "Later, when we run the `fit` method on our model, we will specify the `masking_probability` of `0.3` and link it to the transformer block defined in out model. Through the combination of these parameters, our model will train on sequences where any given timestep will be masked with a probability of 0.3 and it will be our model's training task to infer the target value for that step!\n", "\n", - "To summarize, Masked Language Modeling is implemented by using two blocks in combination:\n", + "To summarize, Masked Language Modeling is implemented by:\n", "\n", - "* `SequenceMaskRandom()` - Used as a pre for model.fit(), it randomly selects items from the sequence to be masked for prediction as targets, by using Keras masking.\n", - "* `ReplaceMaskedEmbeddings()` - Used as a pre for a `TransformerBlock`, it replaces the input embeddings at masked positions for prediction by a dummy trainable embedding, to avoid leakage of the targets.\n", + "* `SequenceMaskRandom()` - Used as a pre for model.fit(), it randomly selects items from the sequence to be masked for prediction as targets, by using Keras masking. This block also adds the necessary configuration to the specified `transformer` block so as it\n", + "is pre-configured with the necessary layers needed to prepare the inputs to the HuggingFace transformer layer and to post-process its outputs. For example, one pre-processing operation is to replace the input embeddings at masked positions for prediction by a dummy trainable embedding, to avoid leakage of the targets.\n", "\n", "\n", "**Read more about the apis used to construct models** \n", @@ -746,7 +765,6 @@ "- [InputBlockV2](https://github.com/NVIDIA-Merlin/models/blob/main/merlin/models/tf/inputs/base.py)\n", "- [Embeddings](https://github.com/NVIDIA-Merlin/models/blob/main/merlin/models/tf/inputs/embedding.py)\n", "- [XLNetBlock](https://github.com/NVIDIA-Merlin/models/blob/main/merlin/models/tf/transformers/block.py)\n", - "- [ReplaceMaskedEmbeddings](https://github.com/NVIDIA-Merlin/models/blob/main/merlin/models/tf/transforms/sequence.py)\n", "- [CategoricalOutput](https://github.com/NVIDIA-Merlin/models/blob/main/merlin/models/tf/outputs/classification.py)\n", "- [.schema.select_by_name](https://github.com/NVIDIA-Merlin/core/blob/main/merlin/schema/schema.py)\n", "- [.schema.select_by_tag](https://github.com/NVIDIA-Merlin/core/blob/main/merlin/schema/schema.py)\n", @@ -770,6 +788,7 @@ " activation='relu',\n", " no_activation_last_layer=True,\n", " )\n", + "transformer_block = mm.XLNetBlock(d_model=dmodel, n_head=4, n_layer=2)\n", "model = mm.Model(\n", " mm.InputBlockV2(\n", " seq_schema,\n", @@ -778,10 +797,7 @@ " ),\n", " ),\n", " mlp_block,\n", - " mm.XLNetBlock(d_model=dmodel, n_head=4, n_layer=2, \n", - " pre=mm.ReplaceMaskedEmbeddings(),\n", - " post=\"inference_hidden_state\",\n", - " ),\n", + " transformer_block,\n", " mm.CategoricalOutput(\n", " train_set_processed.schema.select_by_name(target),\n", " default_loss=\"categorical_crossentropy\",\n", @@ -891,7 +907,7 @@ ], "source": [ "model.compile(run_eagerly=False, optimizer='adam', loss=\"categorical_crossentropy\")\n", - "model.fit(train_set_processed, batch_size=64, epochs=5, pre=mm.SequenceMaskRandom(schema=seq_schema, target=target, masking_prob=0.3))" + "model.fit(train_set_processed, batch_size=64, epochs=5, pre=mm.SequenceMaskRandom(schema=seq_schema, target=target, masking_prob=0.3, transformer=transformer_block))" ] }, { @@ -960,7 +976,7 @@ "model.evaluate(\n", " validation_set_processed,\n", " batch_size=128,\n", - " pre=mm.SequenceMaskLast(schema=seq_schema, target=target),\n", + " pre=mm.SequenceMaskLast(schema=validation_set_processed.schema, target=target, transformer=transformer_block),\n", " return_dict=True\n", ")" ] @@ -1000,7 +1016,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.8.2" } }, "nbformat": 4, diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index 3cfeeb5bbb..2e18b74df6 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -51,6 +51,7 @@ from merlin.models.tf.outputs.contrastive import ContrastiveOutput from merlin.models.tf.prediction_tasks.base import ParallelPredictionBlock, PredictionTask from merlin.models.tf.transforms.features import PrepareFeatures, expected_input_cols_from_schema +from merlin.models.tf.transforms.sequence import SequenceTransform from merlin.models.tf.typing import TabularData from merlin.models.tf.utils.search_utils import find_all_instances_in_layers from merlin.models.tf.utils.tf_utils import ( @@ -779,7 +780,7 @@ def call_train_test( predictions[task.task_name] = task_x sample_weights[task.task_name] = task_sample_weight - self.adjust_predictions_and_targets(predictions, targets) + self.adjust_predictions_and_targets(predictions, targets, sample_weights) if len(predictions) == 1 and len(targets) == 1: predictions = list(predictions.values())[0] @@ -811,68 +812,267 @@ def call_train_test( predictions[task.full_name] = task_x sample_weights[task.full_name] = task_sample_weight - self.adjust_predictions_and_targets(predictions, targets) + self.adjust_predictions_and_targets(predictions, targets, sample_weights) return Prediction(predictions, targets, sample_weights) + def _extract_masked_predictions(self, prediction: TensorLike): + """Extracts the prediction scores corresponding to masked positions (targets). + + This method assumes that the input predictions tensor is 3-D and contains a mask + indicating the positions of the targets. It requires that the mask information has + exactly one masked position per input sequence. The method returns a 2-D dense tensor + containing the prediction score corresponding to each masked position. + + Parameters + ---------- + prediction : TensorLike + A 3-D dense tensor of predictions, with a mask indicating the positions of the targets. + + Returns + ------- + tf.Tensor + A 2-D dense tensor of prediction scores, with one score per input. + + Raises + ------ + ValueError + If the mask does not have exactly one masked position per input sequence. + """ + num_preds_per_example = tf.reduce_sum(tf.cast(prediction._keras_mask, tf.int32), axis=-1) + with tf.control_dependencies( + [ + tf.debugging.assert_equal( + num_preds_per_example, + 1, + message="If targets are scalars (1-D) and predictions are" + " sequential (3-D), the predictions mask should contain" + " one masked position per example", + ) + ] + ): + return tf.boolean_mask(prediction, prediction._keras_mask) + + def _adjust_dense_predictions_and_targets( + self, + prediction: tf.Tensor, + target: TensorLike, + sample_weight: TensorLike, + ): + """Adjusts the dense predictions tensor, the target tensor and sample_weight tensor + to ensure compatibility with most Keras losses and metrics. + + This method applies the following transformations to the target and prediction tensors: + - Converts ragged targets and their masks to dense format. + - Copies the target mask to the prediction mask, if defined. + - If predictions are sequential (3-D) and targets are scalar (1-D), extracts the predictions + at target positions using the predictions mask. + - One-hot encodes targets if their rank is one less than the rank of predictions. + - Ensures that targets have the same shape and dtype as predictions. + + Parameters + ---------- + prediction : tf.Tensor + The prediction tensor as a dense tensor. + target : TensorLike + The target tensor that can be either a dense or ragged tensor. + sample_weight : TensorLike + The sample weight tensor that can be either a dense or ragged tensor. + + Returns: + -------- + A tuple of the adjusted prediction, target, and sample_weight tensors, + with the same dtype and shape. + """ + if isinstance(target, tf.RaggedTensor): + # Converts ragged targets (and ragged mask) to dense + dense_target_mask = None + if getattr(target, "_keras_mask", None) is not None: + dense_target_mask = target._keras_mask.to_tensor() + target = target.to_tensor() + if dense_target_mask is not None: + target._keras_mask = dense_target_mask + + if isinstance(sample_weight, tf.RaggedTensor): + sample_weight = sample_weight.to_tensor() + + if prediction.shape.ndims == 2: + # Removes the mask information as the sequence is summarized into + # a single vector. + prediction._keras_mask = None + + elif getattr(target, "_keras_mask", None) is not None: + # Copies the mask from the targets to the predictions + # because Keras considers the prediction mask in loss + # and metrics computation + if isinstance(target._keras_mask, tf.RaggedTensor): + target._keras_mask = target._keras_mask.to_tensor() + prediction._keras_mask = target._keras_mask + + # Ensuring targets and preds have the same dtype + target = tf.cast(target, prediction.dtype) + + # If targets are scalars (1-D) and predictions are sequential (3-D), + # extract predictions at target position because Keras expects + # predictions and targets to have the same shape. + if getattr(prediction, "_keras_mask", None) is not None: + rank_check = tf.logical_and( + tf.logical_and(tf.rank(target) > 0, tf.shape(target)[-1] == 1), + tf.equal(tf.rank(prediction), 3), + ) + prediction = tf.cond( + rank_check, lambda: self._extract_masked_predictions(prediction), lambda: prediction + ) + + # Ensuring targets are one-hot encoded if they are not + condition = tf.logical_and( + tf.logical_and(tf.rank(target) > 0, tf.shape(target)[-1] == 1), + tf.shape(prediction)[-1] > 1, + ) + target = tf.cond( + condition, + lambda: tf.one_hot( + tf.cast(target, tf.int32), + tf.shape(prediction)[-1], + dtype=prediction.dtype, + ), + lambda: target, + ) + # Makes target shape equal to the predictions tensor, as shape is lost after tf.cond + target = tf.reshape(target, tf.shape(prediction)) + + return prediction, target, sample_weight + + def _adjust_ragged_predictions_and_targets( + self, + prediction: tf.RaggedTensor, + target: TensorLike, + sample_weight: TensorLike, + ): + """Adjusts the prediction (ragged tensor), target and sample weight + to ensure compatibility with most Keras losses and metrics. + + This methods applies the following transformations to the target and prediction tensors: + - Select ragged targets based on the mask information, if defined. + - Remove mask information from the ragged targets and predictions. + - One-hot encode targets if their rank is one less than the rank of predictions. + - Ensure that targets have the same shape and dtype as predictions. + + Parameters + ---------- + prediction : tf.RaggedTensor + The prediction tensor as a ragged tensor. + target : TensorLike + The target tensor that can be either a dense or ragged tensor. + sample_weight : TensorLike + The sample weight tensor that can be either a dense or ragged tensor. + + Returns + ------- + Tuple[tf.Tensor, tf.Tensor] + A tuple containing the adjusted prediction, target and sample_weight tensors. + """ + target_mask = None + if getattr(target, "_keras_mask", None) is not None: + target_mask = target._keras_mask + + if isinstance(target, tf.RaggedTensor) and target_mask is not None: + # Select targets at masked positions and return + # a ragged tensor. + target = tf.ragged.boolean_mask( + target, target_mask.with_row_splits_dtype(target.row_splits.dtype) + ) + + # Ensuring targets and preds have the same dtype + target = tf.cast(target, prediction.dtype) + + # Align sample_weight with the ragged target tensor + if isinstance(target, tf.RaggedTensor) and sample_weight is not None: + if isinstance(sample_weight, tf.RaggedTensor): + # sample_weight is a 2-D tensor, weights in the same sequence are different + if target_mask is not None: + # Select sample weights at masked positions and return a ragged tensor. + sample_weight = tf.ragged.boolean_mask( + sample_weight, + target_mask.with_row_splits_dtype(sample_weight.row_splits.dtype), + ) + else: + # sample_weight is a 1-D tensor, one weight value per sequence + # repeat the weight value for each masked target position + row_lengths = tf.constant(target.row_lengths(), dtype=tf.int64) + sample_weight = tf.repeat(sample_weight, row_lengths) + + # Take the flat values of predictions, targets and sample weihts as Keras + # losses does not support RaggedVariantTensor on GPU: + prediction = prediction.flat_values + if isinstance(target, tf.RaggedTensor): + target = target.flat_values + if isinstance(sample_weight, tf.RaggedTensor): + sample_weight = sample_weight.flat_values + + # Ensuring targets are one-hot encoded if they are not + condition = tf.logical_and( + tf.logical_and(tf.rank(target) > 0, tf.shape(target)[-1] == 1), + tf.shape(prediction)[-1] > 1, + ) + target = tf.cond( + condition, + lambda: tf.one_hot( + tf.cast(target, tf.int32), + tf.shape(prediction)[-1], + dtype=prediction.dtype, + ), + lambda: target, + ) + + # Makes target shape equal to the predictions tensor, as shape is lost after tf.cond + target = tf.reshape(target, tf.shape(prediction)) + + return prediction, target, sample_weight + def adjust_predictions_and_targets( self, predictions: Dict[str, TensorLike], - targets: Optional[Union[tf.Tensor, Dict[str, tf.Tensor]]], + targets: Optional[Union[TensorLike, Dict[str, TensorLike]]], + sample_weights: Optional[Union[TensorLike, Dict[str, TensorLike]]], ): - """Adjusts the predctions and targets, doing the following transformations - if the target is provided: - - Converts ragged targets (and their masks) to dense, so that they are compatible - with most losses and metrics - - Copies the targets mask to predictions mask, if defined - - One-hot encode targets if their tf.rank(targets) == tf.rank(predictions)-1 - - Ensures targets has the same shape and dtype as predicitnos + """Adjusts the predictions and targets to ensure compatibility with + most Keras losses and metrics. + + If the predictions are ragged tensors, `_adjust_ragged_predictions_and_targets` is called, + otherwise `_adjust_dense_predictions_and_targets` is called. Parameters ---------- predictions : Dict[str, TensorLike] - A dict with predictions for the tasks + A dictionary with predictions for the tasks. targets : Optional[Union[tf.Tensor, Dict[str, tf.Tensor]]] - A dict with targets for the tasks + A dictionary with targets for the tasks, or None if targets are not provided. + sample_weights : Optional[Union[tf.Tensor, Dict[str, tf.Tensor]]] + A dictionary with sample weights for the tasks, + or None if sample_weights are not provided. + """ if targets is None: return for k in targets: - # Convert ragged targets (and ragged mask) to dense - if isinstance(targets[k], tf.RaggedTensor): - dense_target_mask = None - if getattr(targets[k], "_keras_mask", None) is not None: - dense_target_mask = targets[k]._keras_mask.to_tensor() - targets[k] = targets[k].to_tensor() - if dense_target_mask is not None: - targets[k]._keras_mask = dense_target_mask - - if getattr(targets[k], "_keras_mask", None) is not None: - # Copies the mask from the targets to the predictions - # because Keras considers the prediction mask in loss - # and metrics computation - predictions[k]._keras_mask = targets[k]._keras_mask - - # Ensuring targets and preds have the same dtype - targets[k] = tf.cast(targets[k], predictions[k].dtype) - - # Ensuring targets are one-hot encoded if they are not - condition = tf.logical_and( - tf.logical_and(tf.rank(targets[k]) > 0, tf.shape(targets[k])[-1] == 1), - tf.shape(predictions[k])[-1] > 1, - ) - targets[k] = tf.cond( - condition, - lambda: tf.one_hot( - tf.cast(targets[k], tf.int32), - tf.shape(predictions[k])[-1], - dtype=predictions[k].dtype, - ), - lambda: targets[k], - ) - # Makes target shape equal to the predictions tensor, as shape is lost after tf.cond - targets[k] = tf.reshape(targets[k], tf.shape(predictions[k])) + if isinstance(predictions[k], tf.RaggedTensor): + ( + predictions[k], + targets[k], + sample_weights[k], + ) = self._adjust_ragged_predictions_and_targets( + predictions[k], targets[k], sample_weights[k] + ) + else: + ( + predictions[k], + targets[k], + sample_weights[k], + ) = self._adjust_dense_predictions_and_targets( + predictions[k], targets[k], sample_weights[k] + ) def train_step(self, data): """Custom train step using the `compute_loss` method.""" @@ -1161,6 +1361,8 @@ def fit( if pre: self._reset_compile_cache() self.train_pre = pre + if isinstance(self.train_pre, SequenceTransform): + self.train_pre.configure_for_train() out = super().fit(**fit_kwargs) @@ -1247,6 +1449,8 @@ def evaluate( if pre: self._reset_compile_cache() self.test_pre = pre + if isinstance(self.test_pre, SequenceTransform): + self.test_pre.configure_for_test() out = super().evaluate( x, diff --git a/merlin/models/tf/outputs/classification.py b/merlin/models/tf/outputs/classification.py index be6352a82e..16e100a2fa 100644 --- a/merlin/models/tf/outputs/classification.py +++ b/merlin/models/tf/outputs/classification.py @@ -328,9 +328,14 @@ def build(self, input_shape): return super().build(input_shape) def call(self, inputs, training=False, **kwargs) -> tf.Tensor: + is_ragged = isinstance(inputs, tf.RaggedTensor) + if is_ragged: + original_inputs = inputs + inputs = inputs.flat_values logits = tf.matmul(inputs, self.table.table.embeddings, transpose_b=True) logits = tf.nn.bias_add(logits, self.bias) - + if is_ragged: + logits = original_inputs.with_flat_values(logits) return logits @property diff --git a/merlin/models/tf/transformers/block.py b/merlin/models/tf/transformers/block.py index 5ba1bca45a..a0062e6c7a 100644 --- a/merlin/models/tf/transformers/block.py +++ b/merlin/models/tf/transformers/block.py @@ -53,7 +53,9 @@ def get_tf_main_layer(hf_model): @tf.keras.utils.register_keras_serializable(package="merlin.models") class TransformerBlock(Block): """ - Class to support HF Transformers for session-based and sequential-based recommendation models. + Base class to support HuggingFace Transformers for session-based and + sequential-based recommendation models. + Parameters ---------- transformer: TransformerBody @@ -71,6 +73,12 @@ class TransformerBlock(Block): A block to use before the main transformer block, by default None post: Optional[Union[str, tf.keras.layers.Layer]] A block to use after the main transformer block, by default None + masking_post: + A block to use to postprocess the output of the transformer block based on + keras mask information, by default None + masking_pre: + A block to use to prepare the inputs to the transformer block based on + keras mask information, by default None """ def __init__( @@ -80,6 +88,8 @@ def __init__( post: Optional[Union[str, tf.keras.layers.Layer]] = None, transformer_pre=PrepareTransformerInputs(), transformer_post: Optional[Union[str, tf.keras.layers.Layer]] = "last_hidden_state", + masking_post: Optional[tf.keras.layers.Layer] = None, + masking_pre: Optional[tf.keras.layers.Layer] = None, **kwargs, ): super().__init__(**kwargs) @@ -107,6 +117,24 @@ def __init__( self.post = post self.pre = pre + self._masking_post = masking_post + self._masking_pre = masking_pre + + @property + def masking_post(self): + return self._masking_post + + @masking_post.setter + def masking_post(self, block): + self._masking_post = block + + @property + def masking_pre(self): + return self._masking_pre + + @masking_pre.setter + def masking_pre(self, block): + self._masking_pre = block def build(self, input_shape=None): """Builds the sequential block @@ -135,6 +163,9 @@ def call(self, inputs: tf.Tensor, **kwargs): @property def to_call_pre(self): + if self.masking_pre: + yield self.masking_pre + if self.pre: yield self.pre @@ -144,6 +175,9 @@ def to_call_pre(self): def to_call_post(self): yield self.transformer_post + if self.masking_post: + yield self.masking_post + if self.post: yield self.post @@ -152,7 +186,15 @@ def get_config(self): config = maybe_serialize_keras_objects( self, config, - ["transformer", "pre", "post", "transformer_pre", "transformer_post"], + [ + "transformer", + "pre", + "post", + "transformer_pre", + "transformer_post", + "masking_pre", + "masking_post", + ], ) return config @@ -160,7 +202,15 @@ def get_config(self): def from_config(cls, config, custom_objects=None): config = maybe_deserialize_keras_objects( config, - ["transformer", "pre", "post", "transformer_pre", "transformer_post"], + [ + "transformer", + "pre", + "post", + "transformer_pre", + "transformer_post", + "masking_post", + "masking_pre", + ], ) output = TransformerBlock(**config) diff --git a/merlin/models/tf/transformers/transforms.py b/merlin/models/tf/transformers/transforms.py index 53d3119931..b071f96ecd 100644 --- a/merlin/models/tf/transformers/transforms.py +++ b/merlin/models/tf/transformers/transforms.py @@ -81,19 +81,26 @@ def call( If inference, returns a 2-D tensor with the hidden states of the target position """ - batch_size = tf.shape(inputs)[0] + if isinstance(inputs, tf.RaggedTensor): + batch_size = tf.shape(inputs.row_lengths())[0] + else: + batch_size = tf.shape(inputs)[0] if not training and not testing: - if getattr(inputs, "_keras_mask", None) is not None: + if isinstance(inputs, tf.RaggedTensor): + inputs = inputs[:, -1:, :] + inputs = tf.squeeze(tf.sparse.to_dense(inputs.to_sparse()), axis=1) + + elif getattr(inputs, "_keras_mask", None) is not None: inputs = tf.reshape( tf.boolean_mask(inputs, inputs._keras_mask), (-1, inputs.shape[-1]) ) - tf.debugging.assert_equal( - tf.shape(inputs)[0], - batch_size, - f"The resulting tensor has {tf.shape(inputs)[0]} rows, which does not match" - f" the inputs batch-size {batch_size}. During inference only one position " - "candidate (the last one) should be masked per example", - ) + tf.debugging.assert_equal( + tf.shape(inputs)[0], + batch_size, + f"The resulting tensor has {tf.shape(inputs)[0]} rows, which does not match" + f" the inputs batch-size {batch_size}. During inference only one position " + "candidate (the last one) should be masked per example", + ) return inputs @@ -232,3 +239,30 @@ def __init__(self, initializer_range: float = 0.02, **kwargs): class SequenceClsIndex(SequenceSummary): def __init__(self, initializer_range: float = 0.02, **kwargs): super().__init__("cls_index", initializer_range=initializer_range, **kwargs) + + +@tf.keras.utils.register_keras_serializable(package="merlin.models") +class TransformerOutputToRagged(Block): + """Converts the dense outputs returned by the transformer layer to + a ragged tensor using masking information. + + This layer takes dense inputs from the transformer layer and + applies the masking information (in the `_keras_mask` attribute) + to produce a ragged tensor output. The resulting tensor contains predictions + only at masked positions (targets). + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.supports_masking = True + + def call(self, inputs: tf.Tensor) -> Dict[str, tf.Tensor]: + if isinstance(inputs, tf.RaggedTensor): + return input + + if getattr(inputs, "_keras_mask", None) is not None: + mask = inputs._keras_mask + if isinstance(mask, tf.RaggedTensor): + mask = mask.to_tensor() + inputs = tf.ragged.boolean_mask(inputs, mask) + return inputs diff --git a/merlin/models/tf/transforms/sequence.py b/merlin/models/tf/transforms/sequence.py index ff7bdd28fd..95e42916b1 100644 --- a/merlin/models/tf/transforms/sequence.py +++ b/merlin/models/tf/transforms/sequence.py @@ -18,12 +18,14 @@ import tensorflow as tf from tensorflow.keras.backend import random_bernoulli +from merlin.models.tf.core import combinators from merlin.models.tf.core.base import Block, BlockType, PredictionOutput from merlin.models.tf.core.combinators import TabularBlock from merlin.models.tf.transforms.features import PrepareFeatures from merlin.models.tf.typing import TabularData from merlin.models.tf.utils import tf_utils from merlin.models.utils import schema_utils +from merlin.models.utils.dependencies import is_transformers_available from merlin.schema import ColumnSchema, Schema, Tags @@ -88,6 +90,9 @@ class SequenceTransform(TabularBlock): P.s. The PrepareFeatures() block is applied to convert the tuple representation of sequential features to RaggedTensors, so that the tensors sequences can be shifted/truncated + transformer: Optional[TransformerBlock] + The transformer block that leverages the group of sequences returned + by the given SequenceTransform, by default None. """ def __init__( @@ -95,6 +100,7 @@ def __init__( schema: Schema, target: Union[str, Tags, ColumnSchema], pre: Optional[BlockType] = None, + transformer=None, **kwargs, ): _pre = PrepareFeatures(schema) @@ -104,6 +110,7 @@ def __init__( self.target = target self.target_name = self._get_target(target) + self.transformer = transformer def _get_target(self, target): if ( @@ -189,6 +196,18 @@ def from_config(cls, config): target = config.pop("target") return cls(schema, target, **config) + def configure_for_train(self): + """Method called by the model.fit() to set additional model's + configuration before calling keras parent class `fit()` + """ + pass + + def configure_for_test(self): + """Method called by the model.evaluate() to check any custom model's + configuration before calling keras parent class `evaluate()` + """ + pass + @Block.registry.register_with_multiple_names("seq_predict_next") @tf.keras.utils.register_keras_serializable(package="merlin_models") @@ -247,6 +266,75 @@ def compute_output_shape(self, input_shape): return new_input_shapes + def compute_mask(self, inputs, mask=None): + new_item_id_seq = inputs[self.target_name][:, :-1] + target_mask = tf.RaggedTensor.from_row_lengths( + values=tf.ones_like(new_item_id_seq.flat_values, dtype=tf.bool), + row_lengths=new_item_id_seq.row_lengths(), + ) + self.target_mask = tf.squeeze(target_mask, axis=-1) + + targets_mask = dict({self.target_name: self.target_mask}) + inputs_mask = dict() + for k, v in inputs.items(): + if k in self.schema.column_names: + inputs_mask[k] = self.target_mask + else: + inputs_mask[k] = None + return (inputs_mask, targets_mask) + + def configure_for_train(self): + """Method called by the model.fit() to set the specialized + `masking_post` and `masking_pre` needed by the TransformerBlock + to align with the SequencePredictNext outputs. + """ + if self.transformer is not None: + if not is_transformers_available(): + raise ImportError("HuggingFace library `transformers` is required") + from merlin.models.tf.transformers.transforms import ( + TransformerInferenceHiddenState, + TransformerOutputToRagged, + ) + + # set the tansformer block with the correct masking block + self.transformer.masking_post = combinators.SequentialBlock( + [TransformerOutputToRagged(), TransformerInferenceHiddenState()] + ) + self.transformer.masking_pre = combinators.SequentialBlock( + [SequenceCausalLastInference(), ExtractMaskFromTargets()] + ) + + def configure_for_test(self): + """Method called by the model.evaluate() to check that the + `masking_post` and `masking_pre` set in the TransformerBlock + are aligned with the evaluation strategy of SequencePredictNext + """ + if self.transformer is not None: + if self.transformer.masking_pre is None: + raise ValueError( + "To evaluate using `SequencePredictNext`, ensure that your TransformerBlock has" + " `masking_pre` set as" + " `combinators.SequentialBlock(" + " [SequenceCausalLastInference(), ExtractMaskFromTargets()]" + ")`." + " You can automatically set `masking_pre` by passing `SequencePredictNext`" + " as the `pre` argument to the `fit` method: " + "`model.fit(..., pre=SequencePredictNext(...))`." + ) + + if any( + isinstance(layer, ReplaceMaskedEmbeddings) + for layer in self.transformer.masking_pre.layers + ): + ValueError( + "You cannot use `ReplaceMaskedEmbeddings` as `masking_pre`" + " of your TransformerBlock with the `SequencePredictNext`" + " evaluation strategy. Please ensure that your Transformer" + " model has been trained with `SequencePredictNext`" + " by passing it as the `pre` argument to the `fit` method: " + "`model.fit(..., pre=SequencePredictNext(...))`." + ) + @Block.registry.register_with_multiple_names("seq_predict_last") @tf.keras.utils.register_keras_serializable(package="merlin_models") @@ -306,6 +394,41 @@ def compute_output_shape(self, input_shape): return new_input_shapes + def compute_mask(self, inputs, mask=None): + new_item_id_seq = inputs[self.target_name][:, :-1] + self.target_mask = self._generate_target_mask(new_item_id_seq) + inputs_mask = dict() + for k, v in inputs.items(): + if k in self.schema.column_names: + inputs_mask[k] = self.target_mask + else: + inputs_mask[k] = None + + return (inputs_mask, self.target_mask) + + def _generate_target_mask(self, ids_seq: tf.RaggedTensor) -> tf.RaggedTensor: + """Returns a bool ragged tensor with the last positions of the sequence masked + + Parameters + ---------- + ids_seq : tf.RaggedTensor + Sequence of ids, which are used to infer how many values + each sequence contains + + Returns + ------- + tf.RaggedTensor + Mask tensor, with True at the last positions + """ + row_lengths = ids_seq.row_lengths(1) + max_seq_length = tf.cast(tf.reduce_max(row_lengths), tf.int32) + + padding_mask = tf.sequence_mask(row_lengths) + targets_mask = tf.ragged.boolean_mask( + tf.cast(tf.one_hot(row_lengths - 1, max_seq_length), tf.bool), padding_mask + ) + return targets_mask + @Block.registry.register_with_multiple_names("seq_predict_random") @tf.keras.utils.register_keras_serializable(package="merlin_models") @@ -354,7 +477,7 @@ def call( positions_matrix = tf.tile( tf.expand_dims(tf.range(0, max_length, dtype=tf.int32), 0), [batch_size, 1] ) - input_mask = positions_matrix < random_targets_indices + self.random_mask = positions_matrix < random_targets_indices target_mask = positions_matrix == random_targets_indices new_target = tf.squeeze( @@ -370,12 +493,48 @@ def call( new_inputs = dict() for k, v in inputs.items(): if k in self.schema.column_names: - new_inputs[k] = tf.ragged.boolean_mask(v, input_mask) + new_inputs[k] = tf.ragged.boolean_mask(v, self.random_mask) else: new_inputs[k] = v return (new_inputs, targets) + def compute_mask(self, inputs, mask=None): + new_item_id_seq = tf.ragged.boolean_mask(inputs[self.target_name], self.random_mask) + + self.target_mask = self._generate_target_mask(new_item_id_seq) + inputs_mask = dict() + for k, v in inputs.items(): + if k in self.schema.column_names: + inputs_mask[k] = self.target_mask + else: + inputs_mask[k] = None + + return (inputs_mask, self.target_mask) + + def _generate_target_mask(self, ids_seq: tf.RaggedTensor) -> tf.RaggedTensor: + """Returns a bool ragged tensor with the last positions of the sequence masked + + Parameters + ---------- + ids_seq : tf.RaggedTensor + Sequence of ids, which are used to infer how many values + each sequence contains + + Returns + ------- + tf.RaggedTensor + Mask tensor, with True at the last positions + """ + row_lengths = ids_seq.row_lengths(1) + max_seq_length = tf.cast(tf.reduce_max(row_lengths), tf.int32) + + padding_mask = tf.sequence_mask(row_lengths) + targets_mask = tf.ragged.boolean_mask( + tf.cast(tf.one_hot(row_lengths - 1, max_seq_length), tf.bool), padding_mask + ) + return targets_mask + @Block.registry.register_with_multiple_names("seq_target_as_input") @tf.keras.utils.register_keras_serializable(package="merlin_models") @@ -575,6 +734,63 @@ def from_config(cls, config): masking_prob = config.pop("masking_prob") return cls(schema, target, masking_prob, **config) + def configure_for_train(self): + """Method called by the model.fit() to set the specialized + `masking_post` and `masking_pre` needed by the TransformerBlock + to align with the SequencePredictNext outputs. + """ + if self.transformer is not None: + if not is_transformers_available(): + raise ImportError("HuggingFace library `transformers` is required") + from merlin.models.tf.transformers.transforms import ( + TransformerInferenceHiddenState, + TransformerOutputToRagged, + ) + + # set the tansformer block with the correct masking blocks + self.transformer.masking_post = combinators.SequentialBlock( + [TransformerOutputToRagged(), TransformerInferenceHiddenState()] + ) + self.transformer.masking_pre = combinators.SequentialBlock( + [SequenceMaskLastInference(), ExtractMaskFromTargets(), ReplaceMaskedEmbeddings()] + ) + + def configure_for_test(self): + """Method called by the model.evaluate() to check that the + `masking_pre` set in the TransformerBlock is aligned with + the evaluation strategy of SequenceMaskRandom + """ + if self.transformer is not None: + if self.transformer.masking_pre is None: + raise ValueError( + "To evaluate using `SequenceMaskRandom`, ensure that your TransformerBlock has" + " `masking_pre` set as" + " `combinators.SequentialBlock(" + " [" + " SequenceMaskLastInference()," + " ExtractMaskFromTargets()," + " ReplaceMaskedEmbeddings()" + " ]" + ")`" + " You can automatically set `masking_pre` by passing `SequenceMaskRandom`" + " as the `pre` argument to the `fit` method:" + " `model.fit(..., pre=SequenceMaskRandom(...))`." + ) + + if not any( + isinstance(layer, ReplaceMaskedEmbeddings) + for layer in self.transformer.masking_pre.layers + ): + ValueError( + " The block `ReplaceMaskedEmbeddings` must be part of the `masking_pre`" + " of your TransformerBlock to be able to use `SequenceMaskRandom`" + " evaluation strategy." + " Please ensure that your Transformer model has been trained with" + " `SequenceMaskRandom` or `SequenceMaskLast`" + " by passing it as the `pre` argument to the `fit` method: " + "`model.fit(..., pre=SequenceMaskRandom(...))`." + ) + @tf.keras.utils.register_keras_serializable(package="merlin.models") class SequenceMaskLast(SequenceTargetAsInput): @@ -646,6 +862,63 @@ def from_config(cls, config): target = config.pop("target") return cls(schema, target, **config) + def configure_for_train(self): + """Method called by the model.fit() to set the specialized + `masking_post` and `masking_pre` needed by the TransformerBlock + to align with the SequencePredictNext outputs. + """ + if self.transformer is not None: + if not is_transformers_available(): + raise ImportError("HuggingFace library `transformers` is required") + from merlin.models.tf.transformers.transforms import ( + TransformerInferenceHiddenState, + TransformerOutputToRagged, + ) + + # set the tansformer block with the correct masking blocks + self.transformer.masking_post = combinators.SequentialBlock( + [TransformerOutputToRagged(), TransformerInferenceHiddenState()] + ) + self.transformer.masking_pre = combinators.SequentialBlock( + [SequenceMaskLastInference(), ExtractMaskFromTargets(), ReplaceMaskedEmbeddings()] + ) + + def configure_for_test(self): + """Method called by the model.evaluate() to check that the + `masking_pre` set in the TransformerBlock is aligned with + the evaluation strategy of SequenceMaskRandom + """ + if self.transformer is not None: + if self.transformer.masking_pre is None: + raise ValueError( + "To evaluate using `SequenceMaskLast`, ensure that your TransformerBlock has" + " `masking_pre` set as" + " `combinators.SequentialBlock(" + " [" + " SequenceMaskLastInference()," + " ExtractMaskFromTargets()," + " ReplaceMaskedEmbeddings()" + " ]" + ")`" + " You can automatically set `masking_pre` by passing `SequenceMaskRandom`" + " or `SequenceMaskLast` as the `pre` argument to the `fit` method:" + " `model.fit(..., pre=SequenceMaskRandom(...))`." + ) + + if not any( + isinstance(layer, ReplaceMaskedEmbeddings) + for layer in self.transformer.masking_pre.layers + ): + ValueError( + "The block `ReplaceMaskedEmbeddings` must be part of the `masking_pre`" + " of your TransformerBlock to be able to use `SequenceMaskRandom`" + " evaluation strategy." + " Please ensure that your Transformer model has been trained with" + " `SequenceMaskRandom` or `SequenceMaskLast`" + " by passing it as the `pre` argument to the `fit` method: " + "`model.fit(..., pre=SequenceMaskLast(...))`." + ) + @tf.keras.utils.register_keras_serializable(package="merlin.models") class SequenceMaskLastInference(Block): @@ -659,7 +932,7 @@ def call(self, inputs, training=False, testing=False): return inputs def compute_mask(self, inputs, mask=None): - """Selects (masks) the nex position after the + """Selects (masks) the next position after the last valid (non-padded) position of the sequential targets to be predicted. This method is called by Keras after call() @@ -685,23 +958,18 @@ def compute_mask(self, inputs, mask=None): @tf.keras.utils.register_keras_serializable(package="merlin.models") class ReplaceMaskedEmbeddings(Block): """Takes a 3D input tensor (batch size x seq. length x embedding dim) and replaces - by a dummy trainable single embedding at the positions to be masked. - This block looks for the Keras mask (`._keras_mask`) in the following order: - 1. Checks if the input tensor has a mask - 2. Checks if there is a single target and if it has a mask - 3. If there are multiple targets (dict) returns the mask of the target - that matches the first 2 dims of the input - This is useful to be used when PredictMasked() transformation is used in - the Loader, which randomly selects some targets to be predicted and uses - Keras Masking to cascade the `_keras_mask`. By replacing input embeddings - at masked positions we avoid target leakage when training models with - Masked Language Modeling (BERT-like) - - **Note:** To support inference, the input sequence and its corresponding mask should be - extended by one position at the end to account for the next-item (`target`) position. - To do this, you should set `SequenceMaskLastInference` as a pre-layer of - `ReplaceMaskedEmbeddings()` using the sequential-block: - ```mm.SequentialBlock([mm.SequenceMaskLastInference(), mm.ReplaceMaskedEmbeddings()])``` + by a dummy trainable single embedding at the positions to be masked. + + This is useful to be used when PredictMasked() transformation is used in + the fit()/eval() methods, which randomly selects some targets to be predicted and uses + Keras Masking to cascade the `_keras_mask`. By replacing input embeddings + at masked positions we avoid target leakage when training models with + Masked Language Modeling (BERT-like). + + To support masked training approach in Transformer-based model, + SequenceMaskRandom and SequenceLastRandom implements `configure_for_train` method + that sets `ReplaceMaskedEmbeddings` as part of the `masking_pre` of + the transformer block. """ def __init__(self, **kwargs): @@ -721,10 +989,8 @@ def build(self, input_shape): def call( self, inputs: Union[tf.Tensor, tf.RaggedTensor], - targets: Optional[Union[tf.Tensor, tf.RaggedTensor, TabularData]] = None, ) -> Union[tf.Tensor, tf.RaggedTensor]: - """If the sequence of input embeddings or the corresponding sequential - targets is masked (with `tensor._keras_mask` defined), + """If the sequence of input embeddings is masked (with `tensor._keras_mask` defined), replaces the input embeddings for masked elements Parameters ---------- @@ -732,22 +998,70 @@ def call( A tensor with sequences of vectors. Needs to be 3D (batch_size, sequence_length, embeddings dim). If inputs._keras_mask is defined uses it to infer the mask - targets : Union[tf.Tensor, tf.RaggedTensor, TabularData], optional - The target values, from which the mask can be extracted - if targets inputs._keras_mask is defined. + Returns ------- Union[tf.Tensor, tf.RaggedTensor] - If training, returns a tensor with the masked inputs replaced by the dummy embedding + returns a tensor with the masked inputs replaced by the dummy embedding """ outputs = inputs - # Infers the mask from the inputs or targets - mask = self._infer_mask_from_inputs_or_targets(inputs, targets) - if mask is not None: + if getattr(inputs, "_keras_mask", None) is not None: # Replaces the embeddings at masked positions by a dummy trainable embedding - outputs = self._replace_masked_embeddings(inputs, mask) + outputs = self._replace_masked_embeddings(inputs, inputs._keras_mask) return outputs + def _replace_masked_embeddings( + self, inputs: Union[tf.Tensor, tf.RaggedTensor], mask: Union[tf.Tensor, tf.RaggedTensor] + ) -> tf.RaggedTensor: + """ + Replaces in the inputs tensors the values masked as targets by a common trainable + embedding + """ + + tf.Assert( + tf_utils.check_inputs_mask_compatible_shape(inputs, mask), + [ + "The inputs and mask need to be compatible: have the same dtype " + "(tf.Tensor or tf.RaggedTensor) and the tf.rank(mask) == tf.rank(inputs)-1" + ], + ) + + if isinstance(mask, tf.RaggedTensor): + mask = mask.with_row_splits_dtype(inputs.row_splits.dtype) + + output = tf.where( + tf.cast(tf.expand_dims(mask, -1), tf.bool), + tf.cast(self.masked_embedding, dtype=inputs.dtype), + inputs, + ) + return output + + +@tf.keras.utils.register_keras_serializable(package="merlin.models") +class ExtractMaskFromTargets(Block): + """ + Recovers the mask information for the inputs from the mask information + stored in the targets. + + This block looks for the Keras mask (`._keras_mask`) in the following order: + 1. Checks if the input tensor has a mask. + 2. Checks if there is a single target and if it has a mask. + 3. If there are multiple targets (dictionary), returns the mask of the target + that matches the first two dimensions of the input. + + This is useful to use when the mask information for the inputs may be lost in + previous non-mask-aware Merlin blocks. + """ + + def call( + self, + inputs: Union[tf.Tensor, tf.RaggedTensor], + targets: Optional[Union[tf.Tensor, tf.RaggedTensor, TabularData]] = None, + ) -> Union[tf.Tensor, tf.RaggedTensor]: + mask = self._infer_mask_from_inputs_or_targets(inputs, targets) + inputs._keras_mask = mask + return inputs + def _infer_mask_from_inputs_or_targets( self, inputs: Union[tf.Tensor, tf.RaggedTensor], @@ -769,7 +1083,7 @@ def _infer_mask_from_inputs_or_targets( for _, v in targets.items(): if getattr( v, "_keras_mask", None - ) is not None and self._check_inputs_mask_compatible_shape( + ) is not None and tf_utils.check_inputs_mask_compatible_shape( inputs, v._keras_mask ): if mask is None: @@ -786,41 +1100,28 @@ def _infer_mask_from_inputs_or_targets( return mask - def _check_inputs_mask_compatible_shape( - self, inputs: Union[tf.Tensor, tf.RaggedTensor], mask: Union[tf.Tensor, tf.RaggedTensor] - ): - result = False - if type(inputs) == type(mask) and (inputs.shape.as_list()[:-1] == mask.shape.as_list()): - if isinstance(inputs, tf.RaggedTensor): - result = tf.reduce_all( - tf.cast(inputs.row_lengths(), tf.int32) == tf.cast(mask.row_lengths(), tf.int32) - ) - else: - result = True - return result - def _replace_masked_embeddings( - self, inputs: Union[tf.Tensor, tf.RaggedTensor], mask: Union[tf.Tensor, tf.RaggedTensor] - ) -> tf.RaggedTensor: - """ - Replaces in the inputs tensors the values masked as targets by a common trainable - embedding - """ - - tf.Assert( - self._check_inputs_mask_compatible_shape(inputs, mask), - [ - "The inputs and mask need to be compatible: have the same dtype " - "(tf.Tensor or tf.RaggedTensor) and the tf.rank(mask) == tf.rank(inputs)-1" - ], - ) +@tf.keras.utils.register_keras_serializable(package="merlin.models") +class SequenceCausalLastInference(Block): + def call(self, inputs, training=False, testing=False): + self.inference_mode = not training and not testing + return inputs - if isinstance(mask, tf.RaggedTensor): - mask = mask.with_row_splits_dtype(inputs.row_splits.dtype) + def compute_mask(self, inputs, mask=None): + """Selects (masks) the last non padded position of the + input sequence to be predicted. + This method is called by Keras after call() + and returns the mask that is going to be assigned + to the input tensors, being accessible + by tensor._keras_mask + """ + if self.inference_mode: + if isinstance(inputs, tf.RaggedTensor): + row_lengths = inputs.row_lengths(1) + max_seq_length = tf.cast(tf.reduce_max(row_lengths), tf.int32) - output = tf.where( - tf.cast(tf.expand_dims(mask, -1), tf.bool), - tf.cast(self.masked_embedding, dtype=inputs.dtype), - inputs, - ) - return output + padding_mask = tf.sequence_mask(row_lengths) + mask = tf.ragged.boolean_mask( + tf.cast(tf.one_hot(row_lengths - 1, max_seq_length), tf.bool), padding_mask + ) + return mask diff --git a/merlin/models/tf/utils/tf_utils.py b/merlin/models/tf/utils/tf_utils.py index a29efbf970..597f7348d1 100644 --- a/merlin/models/tf/utils/tf_utils.py +++ b/merlin/models/tf/utils/tf_utils.py @@ -482,3 +482,35 @@ def list_col_to_ragged(values: tf.Tensor, offsets: tf.Tensor): offsets = tf.cast(offsets, tf.int32) return tf.RaggedTensor.from_row_splits(values, offsets) + + +def check_inputs_mask_compatible_shape( + inputs: Union[tf.Tensor, tf.RaggedTensor], mask: Union[tf.Tensor, tf.RaggedTensor] +): + """Check if the shape and the type of the input and mask tensors are compatible. + Parameters + ---------- + inputs : Union[tf.Tensor, tf.RaggedTensor] + The input tensor, which can be either a dense or ragged tensor. + mask : Union[tf.Tensor, tf.RaggedTensor] + The mask tensor, which can be either a dense or ragged tensor. + + Returns + ------- + bool: + Returns True if the shape of the input and mask tensors are compatible, False otherwise. + + Notes + ----- + The function assumes that the `inputs` tensor has one more dimension than the `mask` tensor, + with the extra dimension typically related to the embeddings dimension. + """ + result = False + if type(inputs) == type(mask) and (inputs.shape.as_list()[:-1] == mask.shape.as_list()): + if isinstance(inputs, tf.RaggedTensor): + result = tf.reduce_all( + tf.cast(inputs.row_lengths(), tf.int32) == tf.cast(mask.row_lengths(), tf.int32) + ) + else: + result = True + return result diff --git a/tests/unit/tf/transformers/test_block.py b/tests/unit/tf/transformers/test_block.py index 4db67d8764..3dd89beff7 100644 --- a/tests/unit/tf/transformers/test_block.py +++ b/tests/unit/tf/transformers/test_block.py @@ -229,7 +229,7 @@ def classification_loader(sequence_testing_data: Dataset): @pytest.mark.parametrize("run_eagerly", [True, False]) -def test_transformer_with_causal_language_modeling(sequence_testing_data: Dataset, run_eagerly): +def test_transformer_with_predict_random(sequence_testing_data: Dataset, run_eagerly): seq_schema = sequence_testing_data.schema.select_by_tag(Tags.SEQUENCE).select_by_tag( Tags.CATEGORICAL ) @@ -237,12 +237,53 @@ def test_transformer_with_causal_language_modeling(sequence_testing_data: Datase target = target_schema.column_names[0] sequence_testing_data.schema = seq_schema + target_schema + model_schema = sequence_testing_data.schema + + transformer_input_dim = 48 + transformer_block = GPT2Block(d_model=transformer_input_dim, n_head=8, n_layer=2) + model = mm.Model( + mm.InputBlockV2( + model_schema, + categorical=mm.Embeddings( + model_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None + ), + ), + mm.MLPBlock([transformer_input_dim]), + transformer_block, + mm.CategoricalOutput( + model_schema.select_by_name(target), default_loss="categorical_crossentropy" + ), + ) - predict_next = mm.SequencePredictNext(schema=seq_schema, target=target) + predict_next = mm.SequencePredictRandom( + schema=seq_schema, target=target, transformer=transformer_block + ) loader = Loader(sequence_testing_data, batch_size=8, shuffle=False) + + testing_utils.model_test( + model, loader, run_eagerly=run_eagerly, reload_model=True, fit_kwargs={"pre": predict_next} + ) + + predict_last = mm.SequencePredictLast( + schema=seq_schema, target=target, transformer=transformer_block + ) + metrics = model.evaluate(loader, batch_size=8, steps=1, return_dict=True, pre=predict_last) + assert len(metrics) > 0 + + +@pytest.mark.parametrize("run_eagerly", [True, False]) +def test_transformer_with_causal_language_modeling(sequence_testing_data: Dataset, run_eagerly): + seq_schema = sequence_testing_data.schema.select_by_tag(Tags.SEQUENCE).select_by_tag( + Tags.CATEGORICAL + ) + target_schema = sequence_testing_data.schema.select_by_tag(Tags.ITEM_ID) + target = target_schema.column_names[0] + + sequence_testing_data.schema = seq_schema + target_schema model_schema = sequence_testing_data.schema transformer_input_dim = 48 + transformer_block = GPT2Block(d_model=transformer_input_dim, n_head=8, n_layer=2) model = mm.Model( mm.InputBlockV2( model_schema, @@ -251,24 +292,36 @@ def test_transformer_with_causal_language_modeling(sequence_testing_data: Datase ), ), mm.MLPBlock([transformer_input_dim]), - GPT2Block(d_model=transformer_input_dim, n_head=8, n_layer=2), + transformer_block, mm.CategoricalOutput( model_schema.select_by_name(target), default_loss="categorical_crossentropy" ), ) - batch = next(iter(loader))[0] - outputs = model(batch) - assert list(outputs.shape) == [8, 4, 51997] + predict_next = mm.SequencePredictNext( + schema=seq_schema, target=target, transformer=transformer_block + ) + loader = Loader(sequence_testing_data, batch_size=8, shuffle=False) + testing_utils.model_test( model, loader, run_eagerly=run_eagerly, reload_model=True, fit_kwargs={"pre": predict_next} ) + batch = next(iter(loader))[0] + outputs = model(batch) + assert list(outputs.shape) == [8, 51997] + metrics = model.evaluate(loader, batch_size=8, steps=1, return_dict=True, pre=predict_next) assert len(metrics) > 0 predictions = model.predict(loader, batch_size=8, steps=1) - assert predictions.shape == (8, 4, 51997) + assert predictions.shape == (8, 51997) + + predict_last = mm.SequencePredictLast( + schema=seq_schema, target=target, transformer=transformer_block + ) + metrics = model.evaluate(loader, batch_size=8, steps=1, return_dict=True, pre=predict_last) + assert len(metrics) > 0 @pytest.mark.parametrize("run_eagerly", [True, False]) @@ -283,6 +336,7 @@ def test_transformer_with_masked_language_modeling(sequence_testing_data: Datase loader = Loader(sequence_testing_data, batch_size=8, shuffle=False) transformer_input_dim = 48 + transformer_block = XLNetBlock(d_model=transformer_input_dim, n_head=8, n_layer=2) model = mm.Model( mm.InputBlockV2( seq_schema, @@ -291,19 +345,15 @@ def test_transformer_with_masked_language_modeling(sequence_testing_data: Datase ), ), mm.MLPBlock([transformer_input_dim]), - XLNetBlock( - d_model=transformer_input_dim, - n_head=8, - n_layer=2, - pre=mm.SequentialBlock([mm.SequenceMaskLastInference(), mm.ReplaceMaskedEmbeddings()]), - post="inference_hidden_state", - ), + transformer_block, mm.CategoricalOutput( seq_schema.select_by_name(target), default_loss="categorical_crossentropy", ), ) - seq_mask_random = mm.SequenceMaskRandom(schema=seq_schema, target=target, masking_prob=0.3) + seq_mask_random = mm.SequenceMaskRandom( + schema=seq_schema, target=target, masking_prob=0.3, transformer=transformer_block + ) inputs, targets = loader.peek() @@ -317,7 +367,9 @@ def test_transformer_with_masked_language_modeling(sequence_testing_data: Datase fit_kwargs={"pre": seq_mask_random}, ) - seq_mask_last = mm.SequenceMaskLast(schema=seq_schema, target=target) + seq_mask_last = mm.SequenceMaskLast( + schema=seq_schema, target=target, transformer=transformer_block + ) metrics = model.evaluate(loader, batch_size=8, steps=1, return_dict=True, pre=seq_mask_last) assert len(metrics) > 0 @@ -340,6 +392,7 @@ def test_transformer_with_masked_language_modeling_check_eval_masked( loader = Loader(sequence_testing_data, batch_size=8, shuffle=False) transformer_input_dim = 48 + transformer_block = BertBlock(d_model=transformer_input_dim, n_head=8, n_layer=2) model = mm.Model( mm.InputBlockV2( seq_schema, @@ -348,19 +401,15 @@ def test_transformer_with_masked_language_modeling_check_eval_masked( ), ), mm.MLPBlock([transformer_input_dim]), - XLNetBlock( - d_model=transformer_input_dim, n_head=8, n_layer=2, pre=mm.ReplaceMaskedEmbeddings() - ), + transformer_block, mm.CategoricalOutput( seq_schema.select_by_name(target), default_loss="categorical_crossentropy", ), ) - seq_mask_random = mm.SequenceMaskRandom(schema=seq_schema, target=target, masking_prob=0.3) - - inputs = itertools.islice(iter(loader), 1) - outputs = model.predict(inputs, pre=seq_mask_random) - assert list(outputs.shape) == [8, 4, 51997] + seq_mask_random = mm.SequenceMaskRandom( + schema=seq_schema, target=target, masking_prob=0.3, transformer=transformer_block + ) testing_utils.model_test( model, @@ -371,6 +420,10 @@ def test_transformer_with_masked_language_modeling_check_eval_masked( metrics=[mm.RecallAt(5000), mm.NDCGAt(5000, seed=4)], ) + inputs = itertools.islice(iter(loader), 1) + outputs = model.predict(inputs, pre=seq_mask_random) + assert list(outputs.shape) == [8, 51997] + # This transform only extracts targets, but without applying mask seq_target_as_input_no_mask = mm.SequenceTargetAsInput(schema=seq_schema, target=target) @@ -429,19 +482,14 @@ def test_transformer_model_with_masking_and_broadcast_to_sequence( dmodel = 32 mlp_block = mm.MLPBlock([128, dmodel], activation="relu") - - dense_block = mm.SequentialBlock( - input_block, - mlp_block, - mm.GPT2Block( - d_model=dmodel, - n_head=4, - n_layer=2, - pre=mm.ReplaceMaskedEmbeddings(), - post="inference_hidden_state", - ), + transformer_block = mm.GPT2Block( + d_model=dmodel, + n_head=4, + n_layer=2, ) + dense_block = mm.SequentialBlock(input_block, mlp_block, transformer_block) + mlp_block2 = mm.MLPBlock([128, dmodel], activation="relu") prediction_task = mm.CategoricalOutput( @@ -449,7 +497,9 @@ def test_transformer_model_with_masking_and_broadcast_to_sequence( ) model = mm.Model(dense_block, mlp_block2, prediction_task) - fit_pre = mm.SequenceMaskRandom(schema=seq_schema, target=target, masking_prob=0.3) + fit_pre = mm.SequenceMaskRandom( + schema=seq_schema, target=target, masking_prob=0.3, transformer=transformer_block + ) testing_utils.model_test( model, sequence_testing_data, diff --git a/tests/unit/tf/transforms/test_sequence.py b/tests/unit/tf/transforms/test_sequence.py index 7e5f4c5af9..0dc8f5a960 100644 --- a/tests/unit/tf/transforms/test_sequence.py +++ b/tests/unit/tf/transforms/test_sequence.py @@ -194,10 +194,11 @@ def test_seq_predict_masked_serialize_deserialize(sequence_testing_data): def test_seq_mask_random_replace_embeddings( sequence_testing_data: Dataset, dense: bool, target_as_dict: bool ): + from merlin.models.tf.transforms.sequence import ExtractMaskFromTargets + sequence_testing_data.schema = sequence_testing_data.schema.select_by_tag( Tags.SEQUENCE ).select_by_name(["item_id_seq", "categories"]) - target = sequence_testing_data.schema.select_by_tag(Tags.ITEM_ID).column_names[0] predict_masked = mm.SequenceMaskRandom( schema=sequence_testing_data.schema, target=target, masking_prob=0.3 @@ -221,7 +222,7 @@ def test_seq_mask_random_replace_embeddings( # Making targets different in dict, the valid one is "target2" which is 2D targets = {"target1": tf.ragged.constant([1, 2, 3, 4, 5, 6, 7, 8]), "target2": targets} - masked_embeddings = mm.ReplaceMaskedEmbeddings() + masked_embeddings = mm.SequentialBlock(ExtractMaskFromTargets(), mm.ReplaceMaskedEmbeddings()) output = masked_embeddings(item_id_emb_seq, targets=targets, training=True) replaced_mask = tf.logical_not(tf.reduce_all(output == item_id_emb_seq, axis=2)) @@ -233,23 +234,27 @@ def test_seq_mask_random_replace_embeddings( def test_replace_masked_input_embeddings_no_target(): + from merlin.models.tf.transforms.sequence import ExtractMaskFromTargets + item_id_emb_seq = tf.random.uniform((8, 10), dtype=tf.float32) item_id_emb_seq._keras_mask = tf.cast( tf.random.uniform((8,), minval=0, maxval=2, dtype=tf.int32), tf.bool ) targets = None - masked_embeddings = mm.ReplaceMaskedEmbeddings() + masked_embeddings = mm.SequentialBlock(ExtractMaskFromTargets(), mm.ReplaceMaskedEmbeddings()) output = masked_embeddings(item_id_emb_seq, targets=targets, training=True) # Checks that no input embedding was replaced, as there was no masking defined tf.Assert(tf.logical_not(tf.reduce_all(output == item_id_emb_seq)), []) def test_not_replace_unmasked_sequence_embeddings(): + from merlin.models.tf.transforms.sequence import ExtractMaskFromTargets + item_id_emb_seq = tf.random.uniform((8, 10), dtype=tf.float32) targets = tf.random.uniform((8, 10), dtype=tf.float32) - masked_embeddings = mm.ReplaceMaskedEmbeddings() + masked_embeddings = mm.SequentialBlock(ExtractMaskFromTargets(), mm.ReplaceMaskedEmbeddings()) output = masked_embeddings(item_id_emb_seq, targets=targets, training=True) # Checks that no input embedding was replaced, as there was no masking defined tf.Assert(tf.reduce_all(output == item_id_emb_seq), [])