From 6360b6f1ce06ebcf2245d82e74f2798c6c87475a Mon Sep 17 00:00:00 2001 From: Gabriel Moreira Date: Wed, 15 Feb 2023 19:30:10 -0300 Subject: [PATCH 1/6] Fixed error that was causing the broadcasted context feature to have fixed size first dim in graph mode and not being compatible with the ragged sequential features --- merlin/models/tf/transforms/features.py | 23 +++++-- tests/unit/tf/transforms/test_features.py | 82 ++++++++++++++++++++++- 2 files changed, 100 insertions(+), 5 deletions(-) diff --git a/merlin/models/tf/transforms/features.py b/merlin/models/tf/transforms/features.py index 2f0406f0c5..6b8b2bee21 100644 --- a/merlin/models/tf/transforms/features.py +++ b/merlin/models/tf/transforms/features.py @@ -812,7 +812,7 @@ def reshape_categorical_input_tensor_for_encoding( @tf.keras.utils.register_keras_serializable(package="merlin.models") -class BroadcastToSequence(tf.keras.layers.Layer): +class BroadcastToSequence(Block): """Broadcast context features to match the timesteps of sequence features. This layer supports mask propagation. If the sequence features have a mask. The @@ -884,6 +884,8 @@ def _get_seq_features_shapes(self, inputs: TabularData): @tf.function def _broadcast(self, inputs, target): seq_features_shapes, sequence_length = self._get_seq_features_shapes(inputs) + first_seq_feature_name = list(seq_features_shapes.keys())[0] + first_seq_feature_value = inputs[first_seq_feature_name] if len(seq_features_shapes) > 0: non_seq_features = set(inputs.keys()).difference(set(seq_features_shapes.keys())) non_seq_target = {} @@ -891,9 +893,20 @@ def _broadcast(self, inputs, target): if fname in self.context_schema.column_names: if target[fname] is None: continue - if isinstance(sequence_length, tf.Tensor): - non_seq_target[fname] = tf.RaggedTensor.from_row_lengths( - tf.repeat(target[fname], sequence_length, axis=0), sequence_length + if isinstance(first_seq_feature_value, tf.RaggedTensor): + target_value = target[fname] + if len(target_value.shape) == 1: + target_value = tf.expand_dims(target_value, -1) + if len(target_value.shape) == 2: + target_value = tf.expand_dims(target_value, -1) + # Here broadcast the context feature in a 3D feature with compatible + # shape to the ragged sequential features + non_seq_target[fname] = ( + tf.ones_like( + target[first_seq_feature_name][:, :, 0:1], + dtype=target[fname].dtype, + ) + * target_value ) else: shape = target[fname].shape @@ -912,6 +925,7 @@ def compute_output_shape( for k in input_shape: if k in self.sequence_schema.column_names: sequence_length = input_shape[k][1] + break context_shapes = {} for k in input_shape: @@ -934,6 +948,7 @@ def compute_mask(self, inputs: TabularData, mask: Optional[TabularData] = None): for k in mask: if mask[k] is not None and k in self.sequence_schema.column_names: sequence_mask = mask[k] + break # no sequence mask found if sequence_mask is None: diff --git a/tests/unit/tf/transforms/test_features.py b/tests/unit/tf/transforms/test_features.py index dd9a7e25eb..8367986036 100644 --- a/tests/unit/tf/transforms/test_features.py +++ b/tests/unit/tf/transforms/test_features.py @@ -855,7 +855,7 @@ def test_sequence_static_dim(self): outputs = broadcast_layer(inputs) self.assertAllEqual(outputs["sequence_embedding"].shape, tf.TensorShape([3, None, 2])) self.assertAllEqual(outputs["context_a"].shape, tf.TensorShape([3, None, 1])) - self.assertAllEqual(outputs["context_b"].shape, tf.TensorShape([3, None])) + self.assertAllEqual(outputs["context_b"].shape, tf.TensorShape([3, None, 1])) @pytest.mark.parametrize( @@ -1010,3 +1010,83 @@ def test_to_target_compute_output_schema(): to_target = mm.ToTarget(schema, "label") output_schema = to_target.compute_output_schema(schema) assert "label" in output_schema.select_by_tag(Tags.TARGET).column_names + + +def test_broadcast_to_sequence_input_block(sequence_testing_data: Dataset): + schema = sequence_testing_data.schema + seq_schema = schema.select_by_name(["item_id_seq", "categories", "item_age_days_norm"]) + context_schema = schema.select_by_name(["user_age"]) + sequence_testing_data.schema = seq_schema + context_schema + + input_block = mm.InputBlockV2( + sequence_testing_data.schema, + embeddings=mm.Embeddings( + seq_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None + ), + post=mm.BroadcastToSequence(context_schema, seq_schema), + aggregation=None, + ) + + batch = mm.sample_batch( + sequence_testing_data, batch_size=100, include_targets=False, to_ragged=True + ) + input_batch = input_block(batch) + assert set(input_batch.keys()) == set( + ["item_id_seq", "categories", "item_age_days_norm", "user_age"] + ) + assert set([len(v.shape) for v in input_batch.values()]) == set([3]) + assert set([v.shape[:-1] for v in input_batch.values()]) == set([tf.TensorShape([100, None])]) + assert list(input_batch["user_age"].shape) == [100, None, 1] + + +def test_model_with_broadcast_to_sequence(sequence_testing_data: Dataset): + schema = sequence_testing_data.schema + seq_schema = schema.select_by_name(["item_id_seq", "categories", "item_age_days_norm"]) + context_schema = schema.select_by_name(["user_age"]) + sequence_testing_data.schema = seq_schema + context_schema + + target = schema.select_by_tag(Tags.ITEM_ID).column_names[0] + item_id_name = schema.select_by_tag(Tags.ITEM_ID).first.properties["domain"]["name"] + + input_block = mm.InputBlockV2( + sequence_testing_data.schema, + embeddings=mm.Embeddings( + seq_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None + ), + post=mm.BroadcastToSequence(context_schema, seq_schema), + ) + + dmodel = 32 + mlp_block = mm.MLPBlock([128, dmodel], activation="relu") + + dense_block = mm.SequentialBlock( + input_block, + mlp_block, + mm.XLNetBlock( + d_model=dmodel, + n_head=4, + n_layer=2, + pre=mm.ReplaceMaskedEmbeddings(), + post="inference_hidden_state", + ), + ) + + mlp_block2 = mm.MLPBlock([128, dmodel], activation="relu", no_activation_last_layer=True) + + prediction_task = mm.CategoricalOutput( + to_call=input_block["categorical"][item_id_name], + ) + model_transformer = mm.Model(dense_block, mlp_block2, prediction_task) + + model_transformer.compile( + run_eagerly=False, + optimizer="adam", + loss="categorical_crossentropy", + metrics=mm.TopKMetricsAggregator.default_metrics(top_ks=[4]), + ) + model_transformer.fit( + sequence_testing_data, + batch_size=512, + epochs=1, + pre=mm.SequenceMaskRandom(schema=seq_schema, target=target, masking_prob=0.3), + ) From c706a0b8cfd8586279b8bf49c5a0caae3c2a3194 Mon Sep 17 00:00:00 2001 From: Gabriel Moreira Date: Thu, 16 Feb 2023 17:28:46 -0300 Subject: [PATCH 2/6] Enforcing non-list (scalar) features to be 2D (batch size,1) if 1D or with last dim undefined (which happens in graph mode) --- merlin/models/tf/transforms/tensor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/merlin/models/tf/transforms/tensor.py b/merlin/models/tf/transforms/tensor.py index 2503b080f0..20620c3a4c 100644 --- a/merlin/models/tf/transforms/tensor.py +++ b/merlin/models/tf/transforms/tensor.py @@ -101,6 +101,9 @@ def call(self, inputs: TabularData, **kwargs) -> TabularData: elif isinstance(val, tf.RaggedTensor): ragged = val else: + # Expanding / setting last dim of non-list features to be 1D + if not self.schema[name].is_list and not self.schema[name].is_ragged: + val = tf.reshape(val, (-1, 1)) outputs[name] = val continue From f591ebe4adfd39a5a072f13687af9cdfd5693fa4 Mon Sep 17 00:00:00 2001 From: Gabriel Moreira Date: Thu, 16 Feb 2023 17:29:52 -0300 Subject: [PATCH 3/6] Making Continuous support_masking=True (to cascade mask) --- merlin/models/tf/inputs/continuous.py | 1 + 1 file changed, 1 insertion(+) diff --git a/merlin/models/tf/inputs/continuous.py b/merlin/models/tf/inputs/continuous.py index 24210247c9..f5bdcc2067 100644 --- a/merlin/models/tf/inputs/continuous.py +++ b/merlin/models/tf/inputs/continuous.py @@ -47,6 +47,7 @@ def __init__( ): if inputs is None: inputs = Tags.CONTINUOUS + self.supports_masking = True super().__init__(inputs, **kwargs) From 6c260cfd72227abed2b237a790c8308bb7725380 Mon Sep 17 00:00:00 2001 From: Gabriel Moreira Date: Thu, 16 Feb 2023 21:36:42 -0300 Subject: [PATCH 4/6] Changing BroadcastToSequence to fix some issues and simplify the masking --- merlin/models/tf/transforms/features.py | 179 +++++++++++++--------- merlin/models/tf/transforms/sequence.py | 10 +- merlin/models/tf/transforms/tensor.py | 8 +- tests/unit/tf/transformers/test_block.py | 54 +++++++ tests/unit/tf/transforms/test_features.py | 123 +++++---------- 5 files changed, 208 insertions(+), 166 deletions(-) diff --git a/merlin/models/tf/transforms/features.py b/merlin/models/tf/transforms/features.py index 6b8b2bee21..d4c415b91a 100644 --- a/merlin/models/tf/transforms/features.py +++ b/merlin/models/tf/transforms/features.py @@ -833,90 +833,121 @@ def __init__(self, context_schema: Schema, sequence_schema: Schema, **kwargs): self.sequence_schema = sequence_schema def call(self, inputs: TabularData) -> TabularData: - inputs = self._broadcast(inputs, inputs) + inputs = self._broadcast(inputs) return inputs - def _get_seq_features_shapes(self, inputs: TabularData): - inputs_sizes = {k: v.shape for k, v in inputs.items()} + def _check_sequence_features(self, inputs: TabularData): + sequence_features = self.sequence_schema.column_names - seq_features_shapes = dict() - for fname, fshape in inputs_sizes.items(): - # Saves the shapes of sequential features - if fname in self.sequence_schema.column_names: - seq_features_shapes[fname] = tuple(fshape[:2]) + if len(sequence_features) == 0: + return + + not_found_seq_features = set(sequence_features).difference(set(inputs.keys())) + if len(not_found_seq_features) > 0: + raise ValueError( + f"Some sequential features were not found in the inputs: {not_found_seq_features}" + ) sequence_length = None sequence_is_ragged = None - if len(seq_features_shapes) > 0: - for k, v in inputs.items(): - if k in self.sequence_schema.column_names: - if isinstance(v, tf.RaggedTensor): - if sequence_is_ragged is False: - raise ValueError( - "sequence features must all be ragged or all dense, not both." - ) - new_sequence_length = v.row_lengths() - sequence_is_ragged = True - else: - if sequence_is_ragged is True: - raise ValueError( - "sequence features must all be ragged or all dense, not both." - ) - new_sequence_length = [v.shape[1]] - sequence_is_ragged = False - - # check sequences lengths match - if sequence_length is not None: - sequence_lengths_equal = tf.math.reduce_all( - tf.equal(new_sequence_length, sequence_length) + for k, v in inputs.items(): + if k in sequence_features: + if isinstance(v, tf.RaggedTensor): + if sequence_is_ragged is False: + raise ValueError( + "Sequential features must all be ragged or all dense, not both." ) - tf.Assert( - sequence_lengths_equal, - [ - "sequence features must share the same sequence lengths", - (sequence_length, new_sequence_length), - ], + new_sequence_length = v.row_lengths() + sequence_is_ragged = True + else: + if sequence_is_ragged is True: + raise ValueError( + "Sequential features must all be ragged or all dense, not both." ) - sequence_length = new_sequence_length + new_sequence_length = [v.shape[1]] + sequence_is_ragged = False + + # check sequences lengths match + if sequence_length is not None: + sequence_lengths_equal = tf.math.reduce_all( + tf.equal(new_sequence_length, sequence_length) + ) + tf.Assert( + sequence_lengths_equal, + [ + "Sequential features must share the same sequence lengths", + (sequence_length, new_sequence_length), + ], + ) + sequence_length = new_sequence_length + + def _check_context_features(self, inputs: TabularData): + context_features = self.context_schema.column_names + + if len(context_features) == 0: + return + + not_found_seq_features = set(context_features).difference(set(inputs.keys())) + if len(not_found_seq_features) > 0: + raise ValueError( + f"Some contextual features were not found in the inputs: {not_found_seq_features}" + ) - return seq_features_shapes, sequence_length + for k in context_features: + v = inputs[k] + if not isinstance(v, tf.Tensor): + raise ValueError(f"A contextual feature ({k}) should be a dense tf.Tensor") + + if len(v.shape) >= 3: + raise ValueError( + f"A contextual feature ({k}) should be a 1D or " "2D tf.Tensor: {v.shape}." + ) @tf.function - def _broadcast(self, inputs, target): - seq_features_shapes, sequence_length = self._get_seq_features_shapes(inputs) - first_seq_feature_name = list(seq_features_shapes.keys())[0] - first_seq_feature_value = inputs[first_seq_feature_name] - if len(seq_features_shapes) > 0: - non_seq_features = set(inputs.keys()).difference(set(seq_features_shapes.keys())) - non_seq_target = {} - for fname in non_seq_features: - if fname in self.context_schema.column_names: - if target[fname] is None: - continue - if isinstance(first_seq_feature_value, tf.RaggedTensor): - target_value = target[fname] - if len(target_value.shape) == 1: - target_value = tf.expand_dims(target_value, -1) - if len(target_value.shape) == 2: - target_value = tf.expand_dims(target_value, -1) - # Here broadcast the context feature in a 3D feature with compatible - # shape to the ragged sequential features - non_seq_target[fname] = ( - tf.ones_like( - target[first_seq_feature_name][:, :, 0:1], - dtype=target[fname].dtype, - ) - * target_value - ) - else: - shape = target[fname].shape - target_shape = shape[:1] + sequence_length + shape[1:] - non_seq_target[fname] = tf.broadcast_to( - tf.expand_dims(target[fname], 1), target_shape - ) - target = {**target, **non_seq_target} + def _broadcast(self, inputs): + self._check_sequence_features(inputs) + self._check_context_features(inputs) + + sequence_features = self.sequence_schema.column_names + context_features = self.context_schema.column_names - return target + if len(sequence_features) == 0 and len(context_features) == 0: + return inputs + + sequence_features_values = list( + [inputs[k] for k in sequence_features if inputs[k] is not None] + ) + if len(sequence_features_values) == 0: + return inputs + template_seq_feature_value = sequence_features_values[0] + + non_seq_target = {} + for fname in context_features: + if inputs[fname] is None: + continue + + if isinstance(template_seq_feature_value, tf.RaggedTensor): + new_value = inputs[fname] + while len(new_value.shape) < len(template_seq_feature_value.shape): + new_value = tf.expand_dims(new_value, 1) + + # Here broadcast the context feature using the same shape + # of a 3D ragged sequential feature with compatible + # So that the context feature shape becomes (batch size, seq length, feature dim) + non_seq_target[fname] = ( + tf.ones_like(template_seq_feature_value[..., :1], dtype=new_value.dtype) + * new_value + ) + else: + shape = inputs[fname].shape + sequence_length = template_seq_feature_value.shape[1] + target_shape = shape[:1] + [sequence_length] + shape[1:] + non_seq_target[fname] = tf.broadcast_to( + tf.expand_dims(inputs[fname], 1), target_shape + ) + inputs = {**inputs, **non_seq_target} + + return inputs def compute_output_shape( self, input_shape: Dict[str, tf.TensorShape] @@ -960,9 +991,7 @@ def compute_mask(self, inputs: TabularData, mask: Optional[TabularData] = None): if mask[k] is None and k in self.context_schema.column_names: masks_context[k] = sequence_mask - masks_other = self._broadcast(inputs, mask) - - new_mask = {**masks_other, **masks_context} + new_mask = {**mask, **masks_context} return new_mask diff --git a/merlin/models/tf/transforms/sequence.py b/merlin/models/tf/transforms/sequence.py index 50700052da..7457d05714 100644 --- a/merlin/models/tf/transforms/sequence.py +++ b/merlin/models/tf/transforms/sequence.py @@ -20,7 +20,6 @@ from merlin.models.tf.core.base import Block, BlockType, PredictionOutput from merlin.models.tf.core.combinators import TabularBlock -from merlin.models.tf.core.prediction import Prediction from merlin.models.tf.transforms.tensor import ListToRagged from merlin.models.tf.typing import TabularData from merlin.models.tf.utils import tf_utils @@ -400,7 +399,7 @@ class SequenceTargetAsInput(SequenceTransform): @tf.function def call( self, inputs: TabularData, targets=None, training=False, testing=False, **kwargs - ) -> Prediction: + ) -> Tuple: self._check_seq_inputs_targets(inputs) new_target = tf.identity(inputs[self.target_name]) @@ -411,7 +410,7 @@ def call( else: raise ValueError("Targets should be None or a dict of tensors") - return Prediction(inputs, targets) + return (inputs, targets) @classmethod def from_config(cls, config): @@ -480,13 +479,14 @@ def compute_mask(self, inputs, mask=None): self.target_mask = self._generate_target_mask(item_id_seq) inputs_mask = dict() - for k, v in inputs.items(): + for k in inputs: if k in self.schema.column_names: inputs_mask[k] = self.target_mask else: inputs_mask[k] = None - return (inputs_mask, self.target_mask) + targets_mask = dict({self.target_name: self.target_mask}) + return (inputs_mask, targets_mask) def _generate_target_mask(self, ids_seq: tf.RaggedTensor) -> tf.RaggedTensor: """Generates a target mask according to the defined probability and diff --git a/merlin/models/tf/transforms/tensor.py b/merlin/models/tf/transforms/tensor.py index 20620c3a4c..a6aaed18d9 100644 --- a/merlin/models/tf/transforms/tensor.py +++ b/merlin/models/tf/transforms/tensor.py @@ -84,7 +84,7 @@ def call(self, inputs: TabularData, **kwargs) -> TabularData: for name, val in inputs.items(): is_ragged = True - if name in self.schema: + if name in self.schema.column_names: val_count = self.schema[name].properties.get("value_count") if ( val_count @@ -102,7 +102,11 @@ def call(self, inputs: TabularData, **kwargs) -> TabularData: ragged = val else: # Expanding / setting last dim of non-list features to be 1D - if not self.schema[name].is_list and not self.schema[name].is_ragged: + if ( + name in self.schema.column_names + and not self.schema[name].is_list + and not self.schema[name].is_ragged + ): val = tf.reshape(val, (-1, 1)) outputs[name] = val continue diff --git a/tests/unit/tf/transformers/test_block.py b/tests/unit/tf/transformers/test_block.py index 76dc1a2d2e..35d4e59a61 100644 --- a/tests/unit/tf/transformers/test_block.py +++ b/tests/unit/tf/transformers/test_block.py @@ -408,3 +408,57 @@ def _metrics_almost_equal(metrics1, metrics2): # Ensures metrics masking only last positions are different then the ones # considering all positions assert not _metrics_almost_equal(metrics_all_positions1, metrics_last_positions) + + +@pytest.mark.parametrize("run_eagerly", [True, False]) +def test_transformer_model_with_masking_and_broadcast_to_sequence( + sequence_testing_data: Dataset, run_eagerly: bool +): + schema = sequence_testing_data.schema + seq_schema = schema.select_by_name(["item_id_seq", "categories", "item_age_days_norm"]) + context_schema = schema.select_by_name(["user_country", "user_age"]) + sequence_testing_data.schema = seq_schema + context_schema + + target = schema.select_by_tag(Tags.ITEM_ID).column_names[0] + item_id_name = schema.select_by_tag(Tags.ITEM_ID).first.properties["domain"]["name"] + + input_block = mm.InputBlockV2( + sequence_testing_data.schema, + embeddings=mm.Embeddings( + seq_schema.select_by_tag(Tags.CATEGORICAL) + + context_schema.select_by_tag(Tags.CATEGORICAL), + sequence_combiner=None, + ), + post=mm.BroadcastToSequence(context_schema, seq_schema), + ) + + dmodel = 32 + mlp_block = mm.MLPBlock([128, dmodel], activation="relu") + + dense_block = mm.SequentialBlock( + input_block, + mlp_block, + mm.GPT2Block( + d_model=dmodel, + n_head=4, + n_layer=2, + pre=mm.ReplaceMaskedEmbeddings(), + post="inference_hidden_state", + ), + ) + + mlp_block2 = mm.MLPBlock([128, dmodel], activation="relu") + + prediction_task = mm.CategoricalOutput( + to_call=input_block["categorical"][item_id_name], + ) + model = mm.Model(dense_block, mlp_block2, prediction_task) + + fit_pre = mm.SequenceMaskRandom(schema=seq_schema, target=target, masking_prob=0.3) + testing_utils.model_test( + model, + sequence_testing_data, + run_eagerly=run_eagerly, + reload_model=False, + fit_kwargs={"pre": fit_pre}, + ) diff --git a/tests/unit/tf/transforms/test_features.py b/tests/unit/tf/transforms/test_features.py index 8367986036..876818f735 100644 --- a/tests/unit/tf/transforms/test_features.py +++ b/tests/unit/tf/transforms/test_features.py @@ -749,7 +749,7 @@ def test_different_sequence_lengths(self): "s2": tf.constant([[[1, 2], [3, 4]], [[6, 3], [2, 3]]]), } layer(inputs) - assert "sequence features must share the same sequence lengths" in str(exc_info.value) + assert "Sequential features must share the same sequence lengths" in str(exc_info.value) def test_different_sequence_lengths_ragged(self): context_schema = Schema([ColumnSchema("c1")]) @@ -762,7 +762,7 @@ def test_different_sequence_lengths_ragged(self): "s2": tf.ragged.constant([[[1, 2], [3, 4]], [[6, 3], [2, 3]]]), } layer(inputs) - assert "sequence features must share the same sequence lengths" in str(exc_info.value) + assert "Sequential features must share the same sequence lengths" in str(exc_info.value) def test_ragged_and_dense_features(self): context_schema = Schema([ColumnSchema("c1")]) @@ -775,7 +775,9 @@ def test_ragged_and_dense_features(self): "s2": tf.ragged.constant([[[1, 2], [3, 4]], [[6, 3], [2, 3]]]), } layer(inputs) - assert "sequence features must all be ragged or all dense, not both" in str(exc_info.value) + assert "Sequential features must all be ragged or all dense, not both" in str( + exc_info.value + ) def test_mask_propagation(self): masking_layer = tf.keras.layers.Masking(mask_value=0) @@ -815,7 +817,7 @@ def test_in_model(self): sequence_schema = Schema([ColumnSchema("b")]) broadcast_layer = BroadcastToSequence(context_schema, sequence_schema) - model = mm.Model(broadcast_layer) + model = mm.Model(broadcast_layer, schema=context_schema + sequence_schema) outputs = model(partially_masked_inputs) self.assertAllEqual(outputs["a"]._keras_mask, tf.ragged.constant([[True], [False, True]])) @@ -858,6 +860,39 @@ def test_sequence_static_dim(self): self.assertAllEqual(outputs["context_b"].shape, tf.TensorShape([3, None, 1])) +def test_broadcast_to_sequence_input_block(sequence_testing_data: Dataset): + schema = sequence_testing_data.schema + seq_schema = schema.select_by_name(["item_id_seq", "categories", "item_age_days_norm"]) + context_schema = schema.select_by_name(["user_age", "user_country"]) + sequence_testing_data.schema = seq_schema + context_schema + + input_block = mm.InputBlockV2( + sequence_testing_data.schema, + embeddings=mm.Embeddings( + seq_schema.select_by_tag(Tags.CATEGORICAL) + + context_schema.select_by_tag(Tags.CATEGORICAL), + sequence_combiner=None, + ), + post=mm.BroadcastToSequence(context_schema, seq_schema), + aggregation=None, + ) + + batch = mm.sample_batch( + sequence_testing_data, batch_size=100, include_targets=False, to_ragged=True + ) + input_batch = input_block(batch) + assert set(input_batch.keys()) == set( + ["item_id_seq", "categories", "item_age_days_norm", "user_age", "user_country"] + ) + assert set([len(v.shape) for v in input_batch.values()]) == set([3]) + assert set([v.shape[:-1] for v in input_batch.values()]) == set([tf.TensorShape([100, None])]) + assert list(input_batch["item_id_seq"].shape) == [100, None, 32] + assert list(input_batch["categories"].shape) == [100, None, 16] + assert list(input_batch["item_age_days_norm"].shape) == [100, None, 1] + assert list(input_batch["user_age"].shape) == [100, None, 1] + assert list(input_batch["user_country"].shape) == [100, None, 8] + + @pytest.mark.parametrize( "only_selected_in_schema", [False, True], @@ -1010,83 +1045,3 @@ def test_to_target_compute_output_schema(): to_target = mm.ToTarget(schema, "label") output_schema = to_target.compute_output_schema(schema) assert "label" in output_schema.select_by_tag(Tags.TARGET).column_names - - -def test_broadcast_to_sequence_input_block(sequence_testing_data: Dataset): - schema = sequence_testing_data.schema - seq_schema = schema.select_by_name(["item_id_seq", "categories", "item_age_days_norm"]) - context_schema = schema.select_by_name(["user_age"]) - sequence_testing_data.schema = seq_schema + context_schema - - input_block = mm.InputBlockV2( - sequence_testing_data.schema, - embeddings=mm.Embeddings( - seq_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None - ), - post=mm.BroadcastToSequence(context_schema, seq_schema), - aggregation=None, - ) - - batch = mm.sample_batch( - sequence_testing_data, batch_size=100, include_targets=False, to_ragged=True - ) - input_batch = input_block(batch) - assert set(input_batch.keys()) == set( - ["item_id_seq", "categories", "item_age_days_norm", "user_age"] - ) - assert set([len(v.shape) for v in input_batch.values()]) == set([3]) - assert set([v.shape[:-1] for v in input_batch.values()]) == set([tf.TensorShape([100, None])]) - assert list(input_batch["user_age"].shape) == [100, None, 1] - - -def test_model_with_broadcast_to_sequence(sequence_testing_data: Dataset): - schema = sequence_testing_data.schema - seq_schema = schema.select_by_name(["item_id_seq", "categories", "item_age_days_norm"]) - context_schema = schema.select_by_name(["user_age"]) - sequence_testing_data.schema = seq_schema + context_schema - - target = schema.select_by_tag(Tags.ITEM_ID).column_names[0] - item_id_name = schema.select_by_tag(Tags.ITEM_ID).first.properties["domain"]["name"] - - input_block = mm.InputBlockV2( - sequence_testing_data.schema, - embeddings=mm.Embeddings( - seq_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None - ), - post=mm.BroadcastToSequence(context_schema, seq_schema), - ) - - dmodel = 32 - mlp_block = mm.MLPBlock([128, dmodel], activation="relu") - - dense_block = mm.SequentialBlock( - input_block, - mlp_block, - mm.XLNetBlock( - d_model=dmodel, - n_head=4, - n_layer=2, - pre=mm.ReplaceMaskedEmbeddings(), - post="inference_hidden_state", - ), - ) - - mlp_block2 = mm.MLPBlock([128, dmodel], activation="relu", no_activation_last_layer=True) - - prediction_task = mm.CategoricalOutput( - to_call=input_block["categorical"][item_id_name], - ) - model_transformer = mm.Model(dense_block, mlp_block2, prediction_task) - - model_transformer.compile( - run_eagerly=False, - optimizer="adam", - loss="categorical_crossentropy", - metrics=mm.TopKMetricsAggregator.default_metrics(top_ks=[4]), - ) - model_transformer.fit( - sequence_testing_data, - batch_size=512, - epochs=1, - pre=mm.SequenceMaskRandom(schema=seq_schema, target=target, masking_prob=0.3), - ) From 0512f6abed69ab2c5f95dd00f2cee793d99d0abe Mon Sep 17 00:00:00 2001 From: Gabriel Moreira Date: Fri, 17 Feb 2023 10:24:00 -0300 Subject: [PATCH 5/6] Fixed tests --- tests/unit/tf/transforms/test_sequence.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/unit/tf/transforms/test_sequence.py b/tests/unit/tf/transforms/test_sequence.py index 05f8011751..18c44cca44 100644 --- a/tests/unit/tf/transforms/test_sequence.py +++ b/tests/unit/tf/transforms/test_sequence.py @@ -162,8 +162,7 @@ def test_seq_random_masking(sequence_testing_data: Dataset): batch, _ = mm.sample_batch(sequence_testing_data, batch_size=8, process_lists=False) - output = predict_masked(batch) - output_x, output_y = output.outputs, output.targets + output_x, output_y = predict_masked(batch) output_y = output_y[target] tf.Assert(tf.reduce_all(output_y == output_x[target]), [output_y, output_x[target]]) @@ -216,8 +215,7 @@ def test_seq_mask_random_replace_embeddings( batch, _ = mm.sample_batch(sequence_testing_data, batch_size=8, process_lists=False) - output = predict_masked(batch) - inputs, targets = output.outputs, output.targets + inputs, targets = predict_masked(batch) targets = targets[target] emb = tf.keras.layers.Embedding(1000, 16) From c5e49b0664c7889a267c09e3030359177c5a28d6 Mon Sep 17 00:00:00 2001 From: Gabriel Moreira Date: Fri, 17 Feb 2023 12:20:17 -0300 Subject: [PATCH 6/6] Fixed test --- tests/unit/tf/transforms/test_features.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/tf/transforms/test_features.py b/tests/unit/tf/transforms/test_features.py index 876818f735..4aed3e4e5c 100644 --- a/tests/unit/tf/transforms/test_features.py +++ b/tests/unit/tf/transforms/test_features.py @@ -885,7 +885,9 @@ def test_broadcast_to_sequence_input_block(sequence_testing_data: Dataset): ["item_id_seq", "categories", "item_age_days_norm", "user_age", "user_country"] ) assert set([len(v.shape) for v in input_batch.values()]) == set([3]) - assert set([v.shape[:-1] for v in input_batch.values()]) == set([tf.TensorShape([100, None])]) + assert set([tuple(list(v.shape)[:-1]) for v in input_batch.values()]) == set( + [tuple([100, None])] + ) assert list(input_batch["item_id_seq"].shape) == [100, None, 32] assert list(input_batch["categories"].shape) == [100, None, 16] assert list(input_batch["item_age_days_norm"].shape) == [100, None, 1]