Skip to content

Commit

Permalink
Add impl_abstract to segment_sum_csr (#2132)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2132

Add FakeTensor support for segement_sum_csr by adding impl_abstract, following: https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9

Reviewed By: bdhirsh

Differential Revision: D51296192

fbshipit-source-id: 8918ddc45e1ba570c8148c3ac172a4d96240e010
  • Loading branch information
Microve authored and facebook-github-bot committed Nov 15, 2023
1 parent 09ab470 commit 975cb01
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
9 changes: 9 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,12 @@ def permute_sparse_features_abstract(
# expected `Sequence[Union[int, types.SymInt]]` but got `Union[int, torch.SymInt]`
permuted_weights = weights.new_empty(output_size)
return (permuted_lengths, permuted_indices, permuted_weights)


@torch.library.impl_abstract("fbgemm::segment_sum_csr")
def segment_sum_csr_abstract(
batch_size: int, csr_seg: Tensor, values: Tensor
) -> Tensor:
output_size = csr_seg.numel() - 1
output = values.new_empty(output_size)
return output
4 changes: 2 additions & 2 deletions fbgemm_gpu/test/failures_dict.json
Original file line number Diff line number Diff line change
Expand Up @@ -439,11 +439,11 @@
"fbgemm::segment_sum_csr": {
"SparseOpsTest.test_aot_dispatch_dynamic__test_segment_sum_csr": {
"comment": "",
"status": "xfail"
"status": "xsuccess"
},
"SparseOpsTest.test_faketensor__test_segment_sum_csr": {
"comment": "",
"status": "xfail"
"status": "xsuccess"
}
},
"fbgemm::stacked_jagged_1d_to_dense": {
Expand Down

0 comments on commit 975cb01

Please sign in to comment.