diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index 05d4112bae..e4101269b5 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -275,3 +275,12 @@ def permute_sparse_features_abstract( # expected `Sequence[Union[int, types.SymInt]]` but got `Union[int, torch.SymInt]` permuted_weights = weights.new_empty(output_size) return (permuted_lengths, permuted_indices, permuted_weights) + + +@torch.library.impl_abstract("fbgemm::segment_sum_csr") +def segment_sum_csr_abstract( + batch_size: int, csr_seg: Tensor, values: Tensor +) -> Tensor: + output_size = csr_seg.numel() - 1 + output = values.new_empty(output_size) + return output diff --git a/fbgemm_gpu/test/failures_dict.json b/fbgemm_gpu/test/failures_dict.json index 8e53be56ed..7a7ecb3334 100644 --- a/fbgemm_gpu/test/failures_dict.json +++ b/fbgemm_gpu/test/failures_dict.json @@ -439,11 +439,11 @@ "fbgemm::segment_sum_csr": { "SparseOpsTest.test_aot_dispatch_dynamic__test_segment_sum_csr": { "comment": "", - "status": "xfail" + "status": "xsuccess" }, "SparseOpsTest.test_faketensor__test_segment_sum_csr": { "comment": "", - "status": "xfail" + "status": "xsuccess" } }, "fbgemm::stacked_jagged_1d_to_dense": {