Skip to content

Commit

Permalink
#grain Fix bug in FastFirstFitPacking handling of scalar numpy arrays.
Browse files Browse the repository at this point in the history
This was causing us to produce scalar packed results, rather than results packed to the proper packing length.

PiperOrigin-RevId: 726668569
  • Loading branch information
aaudiber authored and copybara-github committed Feb 14, 2025
1 parent a5fef0d commit af37d2e
Showing 1 changed file with 62 additions and 0 deletions.
62 changes: 62 additions & 0 deletions grain/_src/python/dataset/transformations/testing_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,68 @@ def test_pack_sequences_two_dimensional_features(self):
num_packing_bins=2,
)

@parameterized.product(
convert_input_to_np=[True, False],
)
def test_meta_features(self, convert_input_to_np: bool):
input_elements = [
{
"inputs": np.asarray([1, 2, 3]),
"targets": np.asarray([10]),
"meta_feature": 3,
},
{
"inputs": np.asarray([4, 5]),
"targets": np.asarray([20, 30, 40]),
"meta_feature": 7,
},
{
"inputs": np.asarray([6]),
"targets": np.asarray([50, 60]),
"meta_feature": 5,
},
]
length_struct = {"inputs": 3, "targets": 4, "meta_feature": 3}

expected_elements = [
{
"inputs": [1, 2, 3],
"targets": [10, 0, 0, 0],
"inputs_segment_ids": [1, 1, 1],
"targets_segment_ids": [1, 0, 0, 0],
"inputs_positions": [0, 1, 2],
"targets_positions": [0, 0, 0, 0],
"meta_feature": [3, 0, 0],
},
{
"inputs": [4, 5, 0],
"targets": [20, 30, 40, 0],
"inputs_segment_ids": [1, 1, 0],
"targets_segment_ids": [1, 1, 1, 0],
"inputs_positions": [0, 1, 0],
"targets_positions": [0, 1, 2, 0],
"meta_feature": [7, 0, 0],
},
{
"inputs": [6, 0, 0],
"targets": [50, 60, 0, 0],
"inputs_segment_ids": [1, 0, 0],
"targets_segment_ids": [1, 1, 0, 0],
"inputs_positions": [0, 0, 0],
"targets_positions": [0, 1, 0, 0],
"meta_feature": [5, 0, 0],
},
]
_common_test_body(
input_elements,
expected_elements,
length_struct,
kwargs=self.kwargs,
num_packing_bins=3,
meta_features=["meta_feature"],
convert_input_to_np=convert_input_to_np,
)

@parameterized.parameters(
{"restore_at_step": 0},
{"restore_at_step": 1},
Expand Down

0 comments on commit af37d2e

Please sign in to comment.