diff --git a/merlin/models/tf/inputs/continuous.py b/merlin/models/tf/inputs/continuous.py index 24210247c9..f5bdcc2067 100644 --- a/merlin/models/tf/inputs/continuous.py +++ b/merlin/models/tf/inputs/continuous.py @@ -47,6 +47,7 @@ def __init__( ): if inputs is None: inputs = Tags.CONTINUOUS + self.supports_masking = True super().__init__(inputs, **kwargs) diff --git a/merlin/models/tf/transforms/features.py b/merlin/models/tf/transforms/features.py index 2f0406f0c5..d4c415b91a 100644 --- a/merlin/models/tf/transforms/features.py +++ b/merlin/models/tf/transforms/features.py @@ -812,7 +812,7 @@ def reshape_categorical_input_tensor_for_encoding( @tf.keras.utils.register_keras_serializable(package="merlin.models") -class BroadcastToSequence(tf.keras.layers.Layer): +class BroadcastToSequence(Block): """Broadcast context features to match the timesteps of sequence features. This layer supports mask propagation. If the sequence features have a mask. The @@ -833,77 +833,121 @@ def __init__(self, context_schema: Schema, sequence_schema: Schema, **kwargs): self.sequence_schema = sequence_schema def call(self, inputs: TabularData) -> TabularData: - inputs = self._broadcast(inputs, inputs) + inputs = self._broadcast(inputs) return inputs - def _get_seq_features_shapes(self, inputs: TabularData): - inputs_sizes = {k: v.shape for k, v in inputs.items()} + def _check_sequence_features(self, inputs: TabularData): + sequence_features = self.sequence_schema.column_names - seq_features_shapes = dict() - for fname, fshape in inputs_sizes.items(): - # Saves the shapes of sequential features - if fname in self.sequence_schema.column_names: - seq_features_shapes[fname] = tuple(fshape[:2]) + if len(sequence_features) == 0: + return + + not_found_seq_features = set(sequence_features).difference(set(inputs.keys())) + if len(not_found_seq_features) > 0: + raise ValueError( + f"Some sequential features were not found in the inputs: {not_found_seq_features}" + ) sequence_length = None sequence_is_ragged = None - if len(seq_features_shapes) > 0: - for k, v in inputs.items(): - if k in self.sequence_schema.column_names: - if isinstance(v, tf.RaggedTensor): - if sequence_is_ragged is False: - raise ValueError( - "sequence features must all be ragged or all dense, not both." - ) - new_sequence_length = v.row_lengths() - sequence_is_ragged = True - else: - if sequence_is_ragged is True: - raise ValueError( - "sequence features must all be ragged or all dense, not both." - ) - new_sequence_length = [v.shape[1]] - sequence_is_ragged = False - - # check sequences lengths match - if sequence_length is not None: - sequence_lengths_equal = tf.math.reduce_all( - tf.equal(new_sequence_length, sequence_length) + for k, v in inputs.items(): + if k in sequence_features: + if isinstance(v, tf.RaggedTensor): + if sequence_is_ragged is False: + raise ValueError( + "Sequential features must all be ragged or all dense, not both." ) - tf.Assert( - sequence_lengths_equal, - [ - "sequence features must share the same sequence lengths", - (sequence_length, new_sequence_length), - ], + new_sequence_length = v.row_lengths() + sequence_is_ragged = True + else: + if sequence_is_ragged is True: + raise ValueError( + "Sequential features must all be ragged or all dense, not both." ) - sequence_length = new_sequence_length + new_sequence_length = [v.shape[1]] + sequence_is_ragged = False + + # check sequences lengths match + if sequence_length is not None: + sequence_lengths_equal = tf.math.reduce_all( + tf.equal(new_sequence_length, sequence_length) + ) + tf.Assert( + sequence_lengths_equal, + [ + "Sequential features must share the same sequence lengths", + (sequence_length, new_sequence_length), + ], + ) + sequence_length = new_sequence_length + + def _check_context_features(self, inputs: TabularData): + context_features = self.context_schema.column_names + + if len(context_features) == 0: + return + + not_found_seq_features = set(context_features).difference(set(inputs.keys())) + if len(not_found_seq_features) > 0: + raise ValueError( + f"Some contextual features were not found in the inputs: {not_found_seq_features}" + ) - return seq_features_shapes, sequence_length + for k in context_features: + v = inputs[k] + if not isinstance(v, tf.Tensor): + raise ValueError(f"A contextual feature ({k}) should be a dense tf.Tensor") + + if len(v.shape) >= 3: + raise ValueError( + f"A contextual feature ({k}) should be a 1D or " "2D tf.Tensor: {v.shape}." + ) @tf.function - def _broadcast(self, inputs, target): - seq_features_shapes, sequence_length = self._get_seq_features_shapes(inputs) - if len(seq_features_shapes) > 0: - non_seq_features = set(inputs.keys()).difference(set(seq_features_shapes.keys())) - non_seq_target = {} - for fname in non_seq_features: - if fname in self.context_schema.column_names: - if target[fname] is None: - continue - if isinstance(sequence_length, tf.Tensor): - non_seq_target[fname] = tf.RaggedTensor.from_row_lengths( - tf.repeat(target[fname], sequence_length, axis=0), sequence_length - ) - else: - shape = target[fname].shape - target_shape = shape[:1] + sequence_length + shape[1:] - non_seq_target[fname] = tf.broadcast_to( - tf.expand_dims(target[fname], 1), target_shape - ) - target = {**target, **non_seq_target} + def _broadcast(self, inputs): + self._check_sequence_features(inputs) + self._check_context_features(inputs) + + sequence_features = self.sequence_schema.column_names + context_features = self.context_schema.column_names - return target + if len(sequence_features) == 0 and len(context_features) == 0: + return inputs + + sequence_features_values = list( + [inputs[k] for k in sequence_features if inputs[k] is not None] + ) + if len(sequence_features_values) == 0: + return inputs + template_seq_feature_value = sequence_features_values[0] + + non_seq_target = {} + for fname in context_features: + if inputs[fname] is None: + continue + + if isinstance(template_seq_feature_value, tf.RaggedTensor): + new_value = inputs[fname] + while len(new_value.shape) < len(template_seq_feature_value.shape): + new_value = tf.expand_dims(new_value, 1) + + # Here broadcast the context feature using the same shape + # of a 3D ragged sequential feature with compatible + # So that the context feature shape becomes (batch size, seq length, feature dim) + non_seq_target[fname] = ( + tf.ones_like(template_seq_feature_value[..., :1], dtype=new_value.dtype) + * new_value + ) + else: + shape = inputs[fname].shape + sequence_length = template_seq_feature_value.shape[1] + target_shape = shape[:1] + [sequence_length] + shape[1:] + non_seq_target[fname] = tf.broadcast_to( + tf.expand_dims(inputs[fname], 1), target_shape + ) + inputs = {**inputs, **non_seq_target} + + return inputs def compute_output_shape( self, input_shape: Dict[str, tf.TensorShape] @@ -912,6 +956,7 @@ def compute_output_shape( for k in input_shape: if k in self.sequence_schema.column_names: sequence_length = input_shape[k][1] + break context_shapes = {} for k in input_shape: @@ -934,6 +979,7 @@ def compute_mask(self, inputs: TabularData, mask: Optional[TabularData] = None): for k in mask: if mask[k] is not None and k in self.sequence_schema.column_names: sequence_mask = mask[k] + break # no sequence mask found if sequence_mask is None: @@ -945,9 +991,7 @@ def compute_mask(self, inputs: TabularData, mask: Optional[TabularData] = None): if mask[k] is None and k in self.context_schema.column_names: masks_context[k] = sequence_mask - masks_other = self._broadcast(inputs, mask) - - new_mask = {**masks_other, **masks_context} + new_mask = {**mask, **masks_context} return new_mask diff --git a/merlin/models/tf/transforms/sequence.py b/merlin/models/tf/transforms/sequence.py index 50700052da..7457d05714 100644 --- a/merlin/models/tf/transforms/sequence.py +++ b/merlin/models/tf/transforms/sequence.py @@ -20,7 +20,6 @@ from merlin.models.tf.core.base import Block, BlockType, PredictionOutput from merlin.models.tf.core.combinators import TabularBlock -from merlin.models.tf.core.prediction import Prediction from merlin.models.tf.transforms.tensor import ListToRagged from merlin.models.tf.typing import TabularData from merlin.models.tf.utils import tf_utils @@ -400,7 +399,7 @@ class SequenceTargetAsInput(SequenceTransform): @tf.function def call( self, inputs: TabularData, targets=None, training=False, testing=False, **kwargs - ) -> Prediction: + ) -> Tuple: self._check_seq_inputs_targets(inputs) new_target = tf.identity(inputs[self.target_name]) @@ -411,7 +410,7 @@ def call( else: raise ValueError("Targets should be None or a dict of tensors") - return Prediction(inputs, targets) + return (inputs, targets) @classmethod def from_config(cls, config): @@ -480,13 +479,14 @@ def compute_mask(self, inputs, mask=None): self.target_mask = self._generate_target_mask(item_id_seq) inputs_mask = dict() - for k, v in inputs.items(): + for k in inputs: if k in self.schema.column_names: inputs_mask[k] = self.target_mask else: inputs_mask[k] = None - return (inputs_mask, self.target_mask) + targets_mask = dict({self.target_name: self.target_mask}) + return (inputs_mask, targets_mask) def _generate_target_mask(self, ids_seq: tf.RaggedTensor) -> tf.RaggedTensor: """Generates a target mask according to the defined probability and diff --git a/merlin/models/tf/transforms/tensor.py b/merlin/models/tf/transforms/tensor.py index 2503b080f0..a6aaed18d9 100644 --- a/merlin/models/tf/transforms/tensor.py +++ b/merlin/models/tf/transforms/tensor.py @@ -84,7 +84,7 @@ def call(self, inputs: TabularData, **kwargs) -> TabularData: for name, val in inputs.items(): is_ragged = True - if name in self.schema: + if name in self.schema.column_names: val_count = self.schema[name].properties.get("value_count") if ( val_count @@ -101,6 +101,13 @@ def call(self, inputs: TabularData, **kwargs) -> TabularData: elif isinstance(val, tf.RaggedTensor): ragged = val else: + # Expanding / setting last dim of non-list features to be 1D + if ( + name in self.schema.column_names + and not self.schema[name].is_list + and not self.schema[name].is_ragged + ): + val = tf.reshape(val, (-1, 1)) outputs[name] = val continue diff --git a/tests/unit/tf/transformers/test_block.py b/tests/unit/tf/transformers/test_block.py index 76dc1a2d2e..35d4e59a61 100644 --- a/tests/unit/tf/transformers/test_block.py +++ b/tests/unit/tf/transformers/test_block.py @@ -408,3 +408,57 @@ def _metrics_almost_equal(metrics1, metrics2): # Ensures metrics masking only last positions are different then the ones # considering all positions assert not _metrics_almost_equal(metrics_all_positions1, metrics_last_positions) + + +@pytest.mark.parametrize("run_eagerly", [True, False]) +def test_transformer_model_with_masking_and_broadcast_to_sequence( + sequence_testing_data: Dataset, run_eagerly: bool +): + schema = sequence_testing_data.schema + seq_schema = schema.select_by_name(["item_id_seq", "categories", "item_age_days_norm"]) + context_schema = schema.select_by_name(["user_country", "user_age"]) + sequence_testing_data.schema = seq_schema + context_schema + + target = schema.select_by_tag(Tags.ITEM_ID).column_names[0] + item_id_name = schema.select_by_tag(Tags.ITEM_ID).first.properties["domain"]["name"] + + input_block = mm.InputBlockV2( + sequence_testing_data.schema, + embeddings=mm.Embeddings( + seq_schema.select_by_tag(Tags.CATEGORICAL) + + context_schema.select_by_tag(Tags.CATEGORICAL), + sequence_combiner=None, + ), + post=mm.BroadcastToSequence(context_schema, seq_schema), + ) + + dmodel = 32 + mlp_block = mm.MLPBlock([128, dmodel], activation="relu") + + dense_block = mm.SequentialBlock( + input_block, + mlp_block, + mm.GPT2Block( + d_model=dmodel, + n_head=4, + n_layer=2, + pre=mm.ReplaceMaskedEmbeddings(), + post="inference_hidden_state", + ), + ) + + mlp_block2 = mm.MLPBlock([128, dmodel], activation="relu") + + prediction_task = mm.CategoricalOutput( + to_call=input_block["categorical"][item_id_name], + ) + model = mm.Model(dense_block, mlp_block2, prediction_task) + + fit_pre = mm.SequenceMaskRandom(schema=seq_schema, target=target, masking_prob=0.3) + testing_utils.model_test( + model, + sequence_testing_data, + run_eagerly=run_eagerly, + reload_model=False, + fit_kwargs={"pre": fit_pre}, + ) diff --git a/tests/unit/tf/transforms/test_features.py b/tests/unit/tf/transforms/test_features.py index dd9a7e25eb..4aed3e4e5c 100644 --- a/tests/unit/tf/transforms/test_features.py +++ b/tests/unit/tf/transforms/test_features.py @@ -749,7 +749,7 @@ def test_different_sequence_lengths(self): "s2": tf.constant([[[1, 2], [3, 4]], [[6, 3], [2, 3]]]), } layer(inputs) - assert "sequence features must share the same sequence lengths" in str(exc_info.value) + assert "Sequential features must share the same sequence lengths" in str(exc_info.value) def test_different_sequence_lengths_ragged(self): context_schema = Schema([ColumnSchema("c1")]) @@ -762,7 +762,7 @@ def test_different_sequence_lengths_ragged(self): "s2": tf.ragged.constant([[[1, 2], [3, 4]], [[6, 3], [2, 3]]]), } layer(inputs) - assert "sequence features must share the same sequence lengths" in str(exc_info.value) + assert "Sequential features must share the same sequence lengths" in str(exc_info.value) def test_ragged_and_dense_features(self): context_schema = Schema([ColumnSchema("c1")]) @@ -775,7 +775,9 @@ def test_ragged_and_dense_features(self): "s2": tf.ragged.constant([[[1, 2], [3, 4]], [[6, 3], [2, 3]]]), } layer(inputs) - assert "sequence features must all be ragged or all dense, not both" in str(exc_info.value) + assert "Sequential features must all be ragged or all dense, not both" in str( + exc_info.value + ) def test_mask_propagation(self): masking_layer = tf.keras.layers.Masking(mask_value=0) @@ -815,7 +817,7 @@ def test_in_model(self): sequence_schema = Schema([ColumnSchema("b")]) broadcast_layer = BroadcastToSequence(context_schema, sequence_schema) - model = mm.Model(broadcast_layer) + model = mm.Model(broadcast_layer, schema=context_schema + sequence_schema) outputs = model(partially_masked_inputs) self.assertAllEqual(outputs["a"]._keras_mask, tf.ragged.constant([[True], [False, True]])) @@ -855,7 +857,42 @@ def test_sequence_static_dim(self): outputs = broadcast_layer(inputs) self.assertAllEqual(outputs["sequence_embedding"].shape, tf.TensorShape([3, None, 2])) self.assertAllEqual(outputs["context_a"].shape, tf.TensorShape([3, None, 1])) - self.assertAllEqual(outputs["context_b"].shape, tf.TensorShape([3, None])) + self.assertAllEqual(outputs["context_b"].shape, tf.TensorShape([3, None, 1])) + + +def test_broadcast_to_sequence_input_block(sequence_testing_data: Dataset): + schema = sequence_testing_data.schema + seq_schema = schema.select_by_name(["item_id_seq", "categories", "item_age_days_norm"]) + context_schema = schema.select_by_name(["user_age", "user_country"]) + sequence_testing_data.schema = seq_schema + context_schema + + input_block = mm.InputBlockV2( + sequence_testing_data.schema, + embeddings=mm.Embeddings( + seq_schema.select_by_tag(Tags.CATEGORICAL) + + context_schema.select_by_tag(Tags.CATEGORICAL), + sequence_combiner=None, + ), + post=mm.BroadcastToSequence(context_schema, seq_schema), + aggregation=None, + ) + + batch = mm.sample_batch( + sequence_testing_data, batch_size=100, include_targets=False, to_ragged=True + ) + input_batch = input_block(batch) + assert set(input_batch.keys()) == set( + ["item_id_seq", "categories", "item_age_days_norm", "user_age", "user_country"] + ) + assert set([len(v.shape) for v in input_batch.values()]) == set([3]) + assert set([tuple(list(v.shape)[:-1]) for v in input_batch.values()]) == set( + [tuple([100, None])] + ) + assert list(input_batch["item_id_seq"].shape) == [100, None, 32] + assert list(input_batch["categories"].shape) == [100, None, 16] + assert list(input_batch["item_age_days_norm"].shape) == [100, None, 1] + assert list(input_batch["user_age"].shape) == [100, None, 1] + assert list(input_batch["user_country"].shape) == [100, None, 8] @pytest.mark.parametrize( diff --git a/tests/unit/tf/transforms/test_sequence.py b/tests/unit/tf/transforms/test_sequence.py index 05f8011751..18c44cca44 100644 --- a/tests/unit/tf/transforms/test_sequence.py +++ b/tests/unit/tf/transforms/test_sequence.py @@ -162,8 +162,7 @@ def test_seq_random_masking(sequence_testing_data: Dataset): batch, _ = mm.sample_batch(sequence_testing_data, batch_size=8, process_lists=False) - output = predict_masked(batch) - output_x, output_y = output.outputs, output.targets + output_x, output_y = predict_masked(batch) output_y = output_y[target] tf.Assert(tf.reduce_all(output_y == output_x[target]), [output_y, output_x[target]]) @@ -216,8 +215,7 @@ def test_seq_mask_random_replace_embeddings( batch, _ = mm.sample_batch(sequence_testing_data, batch_size=8, process_lists=False) - output = predict_masked(batch) - inputs, targets = output.outputs, output.targets + inputs, targets = predict_masked(batch) targets = targets[target] emb = tf.keras.layers.Embedding(1000, 16)