diff --git a/tests/unit/tf/models/test_base.py b/tests/unit/tf/models/test_base.py index 9f0fc6d2eb..d408ac38d1 100644 --- a/tests/unit/tf/models/test_base.py +++ b/tests/unit/tf/models/test_base.py @@ -923,11 +923,11 @@ def test_retrieval_model_query_candidate(ecommerce_data: Dataset, run_eagerly=Tr assert isinstance(reloaded_model.query_encoder, mm.EmbeddingEncoder) assert isinstance(reloaded_model.candidate_encoder, mm.EmbeddingEncoder) - queries = model.query_embeddings(ecommerce_data, batch_size=10, index=Tags.USER_ID).compute() + queries = model.query_embeddings(ecommerce_data, batch_size=16, index=Tags.USER_ID).compute() _check_embeddings(queries, 100, "user_id") candidates = model.candidate_embeddings( - ecommerce_data, batch_size=10, index=candidate + ecommerce_data, batch_size=16, index=candidate ).compute() _check_embeddings(candidates, 100, "item_id") diff --git a/tests/unit/tf/models/test_retrieval.py b/tests/unit/tf/models/test_retrieval.py index 521a974102..708874b1e5 100644 --- a/tests/unit/tf/models/test_retrieval.py +++ b/tests/unit/tf/models/test_retrieval.py @@ -898,11 +898,11 @@ def test_two_tower_v2_export_embeddings( model, _ = testing_utils.model_test(model, ecommerce_data, reload_model=False) - queries = model.query_embeddings(ecommerce_data, batch_size=10, index=Tags.USER_ID).compute() + queries = model.query_embeddings(ecommerce_data, batch_size=16, index=Tags.USER_ID).compute() _check_embeddings(queries, 100, 8, "user_id") candidates = model.candidate_embeddings( - ecommerce_data, batch_size=10, index=Tags.ITEM_ID + ecommerce_data, batch_size=16, index=Tags.ITEM_ID ).compute() _check_embeddings(candidates, 100, 8, "item_id") @@ -918,11 +918,11 @@ def test_mf_v2_export_embeddings( model, _ = testing_utils.model_test(model, ecommerce_data, reload_model=False) - queries = model.query_embeddings(ecommerce_data, batch_size=10, index=Tags.USER_ID).compute() + queries = model.query_embeddings(ecommerce_data, batch_size=16, index=Tags.USER_ID).compute() _check_embeddings(queries, 100, 8, "user_id") candidates = model.candidate_embeddings( - ecommerce_data, batch_size=10, index=Tags.ITEM_ID + ecommerce_data, batch_size=16, index=Tags.ITEM_ID ).compute() _check_embeddings(candidates, 100, 8, "item_id") diff --git a/tests/unit/tf/transformers/test_block.py b/tests/unit/tf/transformers/test_block.py index 3dd89beff7..39b4c87494 100644 --- a/tests/unit/tf/transformers/test_block.py +++ b/tests/unit/tf/transformers/test_block.py @@ -182,7 +182,7 @@ def test_transformer_as_classification_model(sequence_testing_data: Dataset, run batch = loader.peek()[0] outputs = model(batch) - assert list(outputs.shape) == [50, 63] + assert list(outputs.shape) == [64, 63] testing_utils.model_test(model, loader, run_eagerly=run_eagerly) @@ -223,7 +223,7 @@ def classification_loader(sequence_testing_data: Dataset): sequence_testing_data.schema = schema dataloader = mm.Loader( sequence_testing_data, - batch_size=50, + batch_size=64, ).map(mm.ToTarget(schema, "user_country", one_hot=True)) return dataloader, dataloader.output_schema