diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index cfea8e1f88..5751817173 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -280,7 +280,6 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     record_cache_metrics: RecordCacheMetrics
     uvm_cache_stats: torch.Tensor
     local_uvm_cache_stats: torch.Tensor
-    linear_cache_indices_list: List[Tensor]

     def __init__(  # noqa C901
         self,
@@ -941,6 +940,11 @@ def forward(  # noqa: C901
                 B_offsets=vbe_metadata.B_offsets,
                 max_B=vbe_metadata.max_B,
             )
+
+        # Storing indices and offsets for linear_cache_indices recomputation
+        self._indices = indices
+        self._offsets = offsets
+
         self.step += 1
         if len(self.timesteps_prefetched) == 0:
             self._prefetch(indices, offsets)
@@ -1234,8 +1238,6 @@ def _prefetch(self, indices: Tensor, offsets: Tensor) -> None:
         )
         self.lxu_cache_locations_list.append(lxu_cache_locations)
-        if self.prefetch_pipeline:
-            self.linear_cache_indices_list.append(linear_cache_indices)

         if self.gather_uvm_cache_stats:
             # Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64).
@@ -1255,8 +1257,6 @@ def _prefetch_tensors_record_stream(
         for t in self.lxu_cache_locations_list:
             t.record_stream(forward_stream)
-        for t in self.linear_cache_indices_list:
-            t.record_stream(forward_stream)

     def _update_cache_miss_counter(
         self,
@@ -1580,8 +1580,9 @@ def _apply_cache_state(
             0, device=self.current_device, dtype=torch.int32
         ).fill_(-1)
         self.lxu_cache_locations = self.lxu_cache_locations_empty
+        self._indices = self.lxu_cache_locations_empty
+        self._offsets = self.lxu_cache_locations_empty
         self.prefetch_stream: Optional[torch.cuda.Stream] = None
-        self.linear_cache_indices_list = []

         self._init_uvm_cache_stats()

@@ -1796,7 +1797,12 @@ def _update_cache_counter_and_locations(
             self.lxu_cache_locations,
         )

-        linear_cache_indices = self.linear_cache_indices_list.pop(0)
+        # Recompute linear_cache_indices
+        linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
+            self.cache_hash_size_cumsum,
+            self._indices,
+            self._offsets,
+        )
         lxu_cache_locations_new = torch.ops.fbgemm.lxu_cache_lookup(
             linear_cache_indices,
             self.lxu_cache_state,
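
Note on the change above: instead of keeping every prefetched linear_cache_indices tensor alive in a Python list (and recording each one on the forward stream), the module now stores only the raw indices/offsets in forward() and recomputes the linearized indices with torch.ops.fbgemm.linearize_cache_indices when cache counters and locations are updated. The snippet below is a minimal pure-PyTorch sketch of what that linearization does conceptually (shifting each table's indices into one flat cache index space via cache_hash_size_cumsum); the helper name linearize_cache_indices_ref, the equal-batch-size assumption, and the toy tensors are illustrative only, not the actual FBGEMM kernel.

import torch

def linearize_cache_indices_ref(
    cache_hash_size_cumsum: torch.Tensor,  # shape [T + 1]: cumulative hash sizes per table
    indices: torch.Tensor,                 # all lookup indices for the batch, flattened
    offsets: torch.Tensor,                 # shape [T * B + 1]: CSR-style offsets into `indices`
) -> torch.Tensor:
    # Map each table's indices into one flat cache index space by adding that
    # table's cumulative offset, so a single LXU cache lookup can serve all tables.
    T = cache_hash_size_cumsum.numel() - 1
    B = (offsets.numel() - 1) // T  # assumes the same batch size for every table
    linear = torch.empty_like(indices)
    for t in range(T):
        start = int(offsets[t * B])
        end = int(offsets[(t + 1) * B])
        linear[start:end] = indices[start:end] + cache_hash_size_cumsum[t]
    return linear

# Toy usage: two tables with hash sizes 10 and 20 (cumsum [0, 10, 30]), batch size 1.
cache_hash_size_cumsum = torch.tensor([0, 10, 30])
indices = torch.tensor([3, 7, 5])   # [3, 7] belong to table 0, [5] to table 1
offsets = torch.tensor([0, 2, 3])   # table 0 covers indices[0:2], table 1 covers indices[2:3]
print(linearize_cache_indices_ref(cache_hash_size_cumsum, indices, offsets))
# tensor([ 3,  7, 15])

Recomputing on demand in _update_cache_counter_and_locations trades a small amount of work for memory and bookkeeping: the stored self._indices/self._offsets replace the per-prefetch list, which is why the record_stream loop over linear_cache_indices_list in _prefetch_tensors_record_stream can be dropped.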