Recompute linear_cache_indices for pipeline prefetching (#2147)
Summary:
Pull Request resolved: #2147

When pipeline prefetching is enabled (`prefetch_pipeline=True`) for
`EmbeddingLocation.MANAGED_CACHING`, TBE has to update
`lxu_cache_locations` before the backward pass to ensure cache
consistency, and that update requires `linear_cache_indices` as an
input.  Prior to this diff, TBE kept `linear_cache_indices` alive
after prefetching until the tensor was consumed by the
`lxu_cache_locations` update.  Holding these tensors put significant
pressure on memory and prevented pipeline prefetching from being
enabled for some models.  This diff removes that limitation by
recomputing `linear_cache_indices` from the saved `indices` and
`offsets` when it is needed.
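
In short, instead of holding each prefetched `linear_cache_indices`
tensor until it is consumed, the module now keeps only `indices` and
`offsets` and recomputes the linearized indices on demand.  A minimal
sketch of the before/after pattern (illustrative only, not the actual
TBE code; the class and method names here are hypothetical, while
`linearize_cache_indices`, `cache_hash_size_cumsum`, `_indices`, and
`_offsets` match identifiers in the diff below):

```python
import torch  # requires fbgemm_gpu for torch.ops.fbgemm.* to resolve

class BeforeThisDiff:
    """Old pattern: keep the prefetched linear_cache_indices tensor alive."""

    def prefetch(self, indices, offsets):
        linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
            self.cache_hash_size_cumsum, indices, offsets
        )
        # Held across the whole forward/backward window -> memory pressure.
        self.linear_cache_indices_list.append(linear_cache_indices)

    def update_cache_locations(self):
        linear_cache_indices = self.linear_cache_indices_list.pop(0)
        ...  # lxu_cache_lookup / lxu_cache_locations update

class AfterThisDiff:
    """New pattern: keep only indices/offsets, recompute when needed."""

    def forward(self, indices, offsets):
        self._indices, self._offsets = indices, offsets  # already-live inputs

    def update_cache_locations(self):
        linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
            self.cache_hash_size_cumsum, self._indices, self._offsets
        )
        ...  # lxu_cache_lookup / lxu_cache_locations update
```

The trade-off is a small amount of recomputation in exchange for not
keeping one `linear_cache_indices` tensor per in-flight prefetch.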

Reviewed By: jspark1105

Differential Revision: D50983176

fbshipit-source-id: 050c4bb59db4697a5d53d09b52e39e101ecd50ee
sryap authored and facebook-github-bot committed Nov 21, 2023
1 parent 5436320 commit 37111f5
Showing 1 changed file with 13 additions and 7 deletions.
@@ -280,7 +280,6 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     record_cache_metrics: RecordCacheMetrics
     uvm_cache_stats: torch.Tensor
     local_uvm_cache_stats: torch.Tensor
-    linear_cache_indices_list: List[Tensor]
 
     def __init__( # noqa C901
         self,
@@ -941,6 +940,11 @@ def forward( # noqa: C901
                 B_offsets=vbe_metadata.B_offsets,
                 max_B=vbe_metadata.max_B,
             )
+
+        # Storing indices and offsets for linear_cache_indices recomputation
+        self._indices = indices
+        self._offsets = offsets
+
         self.step += 1
         if len(self.timesteps_prefetched) == 0:
             self._prefetch(indices, offsets)
@@ -1234,8 +1238,6 @@ def _prefetch(self, indices: Tensor, offsets: Tensor) -> None:
         )
 
         self.lxu_cache_locations_list.append(lxu_cache_locations)
-        if self.prefetch_pipeline:
-            self.linear_cache_indices_list.append(linear_cache_indices)
 
         if self.gather_uvm_cache_stats:
             # Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64).
@@ -1255,8 +1257,6 @@ def _prefetch_tensors_record_stream(
 
         for t in self.lxu_cache_locations_list:
             t.record_stream(forward_stream)
-        for t in self.linear_cache_indices_list:
-            t.record_stream(forward_stream)
 
     def _update_cache_miss_counter(
         self,
@@ -1580,8 +1580,9 @@ def _apply_cache_state(
             0, device=self.current_device, dtype=torch.int32
         ).fill_(-1)
         self.lxu_cache_locations = self.lxu_cache_locations_empty
+        self._indices = self.lxu_cache_locations_empty
+        self._offsets = self.lxu_cache_locations_empty
         self.prefetch_stream: Optional[torch.cuda.Stream] = None
-        self.linear_cache_indices_list = []
 
         self._init_uvm_cache_stats()
 
@@ -1796,7 +1797,12 @@ def _update_cache_counter_and_locations(
             self.lxu_cache_locations,
         )
 
-        linear_cache_indices = self.linear_cache_indices_list.pop(0)
+        # Recompute linear_cache_indices
+        linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
+            self.cache_hash_size_cumsum,
+            self._indices,
+            self._offsets,
+        )
         lxu_cache_locations_new = torch.ops.fbgemm.lxu_cache_lookup(
             linear_cache_indices,
             self.lxu_cache_state,
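
For reference, the recomputed `linear_cache_indices` places every
table's indices into a single linear cache-index space.  The sketch
below is an assumption about the semantics of
`torch.ops.fbgemm.linearize_cache_indices` (simplified, ignoring
sentinel handling for rows that are not cached), not the actual
kernel:

```python
import torch

def linearize_cache_indices_ref(cache_hash_size_cumsum, indices, offsets):
    # Assumed semantics, simplified: offsets delimits per-(table, sample)
    # segments of `indices`; table t owns offsets[t * B] : offsets[(t + 1) * B],
    # where B is the batch size, and its indices are shifted by the table's
    # starting position in the linearized cache-index space.
    T = cache_hash_size_cumsum.numel() - 1
    B = (offsets.numel() - 1) // T
    linear = torch.empty_like(indices)
    for t in range(T):
        start = int(offsets[t * B])
        end = int(offsets[(t + 1) * B])
        linear[start:end] = indices[start:end] + cache_hash_size_cumsum[t]
    return linear
```

The recomputation consumes `self._indices` and `self._offsets`, which
the forward pass now stores (see the `forward` hunk above), so no
prefetched `linear_cache_indices` tensor has to be retained.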
