diff --git a/src/torchaudio/functional/filtering.py b/src/torchaudio/functional/filtering.py
index 1628d82c33..541c56c475 100644
--- a/src/torchaudio/functional/filtering.py
+++ b/src/torchaudio/functional/filtering.py
@@ -1662,8 +1662,8 @@ def vad(
             flushedLen_ns = (measures_len - num_measures_to_flush) * measure_period_ns
             break
     # end for window
-    if not has_triggered:
-        return waveform[..., :0].view(shape[:-1] + torch.Size([0]))
+    if not has_triggered and shape[-1] >= fixed_pre_trigger_len_ns:
+        return waveform[..., :fixed_pre_trigger_len_ns].view(shape[:-1] + torch.Size([fixed_pre_trigger_len_ns]))
 
     res = waveform[:, max(pos - samplesLen_ns + flushedLen_ns, 0) :]
     # unpack batch
diff --git a/test/torchaudio_unittest/transforms/transforms_test_impl.py b/test/torchaudio_unittest/transforms/transforms_test_impl.py
index 2e70ab4ad3..0120c997c0 100644
--- a/test/torchaudio_unittest/transforms/transforms_test_impl.py
+++ b/test/torchaudio_unittest/transforms/transforms_test_impl.py
@@ -481,15 +481,20 @@ def test_specaugment(self, n_time_masks, time_mask_param, n_freq_masks, freq_mas
 
     @parameterized.expand(
         [
-            ((32000,), (0,), 16000),
-            ((1, 32000), (1, 0), 32000),
-            ((2, 44100), (2, 0), 32000),
-            ((2, 2, 44100), (2, 2, 0), 32000),
+            ((32000,), (0,), 16000, 0.0),
+            ((1, 32000), (1, 0), 32000, 0.0),
+            ((2, 44100), (2, 0), 32000, 0.0),
+            ((2, 2, 44100), (2, 2, 0), 32000, 0.0),
+            ((32000,), (16000,), 16000, 1.0),
+            ((32000,), (32000,), 16000, 4.0),
+            ((1, 32000), (1, 32000), 32000, 1.0),
+            ((2, 44100), (2, 32000), 32000, 1.0),
+            ((2, 2, 44100), (2, 2, 32000), 32000, 1.0),
         ]
     )
-    def test_vad_on_zero_audio(self, input_shape, output_shape, sample_rate: int):
-        """VAD should return zero when input is zero Tensor"""
+    def test_vad_on_zero_audio(self, input_shape, output_shape, sample_rate: int, pre_trigger_time: float):
+        """VAD should return zero when input is zero Tensor when pre_trigger_time=0"""
         inpt = torch.zeros(input_shape, dtype=self.dtype, device=self.device)
         expected_output = torch.zeros(output_shape, dtype=self.dtype, device=self.device)
-        result = T.Vad(sample_rate)(inpt)
+        result = T.Vad(sample_rate, pre_trigger_time=pre_trigger_time)(inpt)
         self.assertEqual(result, expected_output)