From 975cb0156db3f3b12280ef2198c9e972678bbb61 Mon Sep 17 00:00:00 2001 From: Shuai Yang Date: Tue, 14 Nov 2023 17:48:56 -0800 Subject: [PATCH] Add impl_abstract to segment_sum_csr (#2132) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2132 Add FakeTensor support for segement_sum_csr by adding impl_abstract, following: https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9 Reviewed By: bdhirsh Differential Revision: D51296192 fbshipit-source-id: 8918ddc45e1ba570c8148c3ac172a4d96240e010 --- fbgemm_gpu/fbgemm_gpu/sparse_ops.py | 9 +++++++++ fbgemm_gpu/test/failures_dict.json | 4 ++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index 05d4112bae..e4101269b5 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -275,3 +275,12 @@ def permute_sparse_features_abstract( # expected `Sequence[Union[int, types.SymInt]]` but got `Union[int, torch.SymInt]` permuted_weights = weights.new_empty(output_size) return (permuted_lengths, permuted_indices, permuted_weights) + + +@torch.library.impl_abstract("fbgemm::segment_sum_csr") +def segment_sum_csr_abstract( + batch_size: int, csr_seg: Tensor, values: Tensor +) -> Tensor: + output_size = csr_seg.numel() - 1 + output = values.new_empty(output_size) + return output diff --git a/fbgemm_gpu/test/failures_dict.json b/fbgemm_gpu/test/failures_dict.json index 8e53be56ed..7a7ecb3334 100644 --- a/fbgemm_gpu/test/failures_dict.json +++ b/fbgemm_gpu/test/failures_dict.json @@ -439,11 +439,11 @@ "fbgemm::segment_sum_csr": { "SparseOpsTest.test_aot_dispatch_dynamic__test_segment_sum_csr": { "comment": "", - "status": "xfail" + "status": "xsuccess" }, "SparseOpsTest.test_faketensor__test_segment_sum_csr": { "comment": "", - "status": "xfail" + "status": "xsuccess" } }, "fbgemm::stacked_jagged_1d_to_dense": {