diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py
index 56cec4db89e64..cbc3668997817 100644
--- a/tests/lora/test_baichuan.py
+++ b/tests/lora/test_baichuan.py
@@ -63,12 +63,11 @@ def test_baichuan_lora(baichuan_lora_files):
         assert output2[i] == expected_lora_output[i]
 
 
-@pytest.mark.skip("Requires multiple GPUs")
 @pytest.mark.parametrize("fully_sharded", [True, False])
-def test_baichuan_tensor_parallel_equality(baichuan_lora_files, fully_sharded):
-    # Cannot use as it will initialize torch.cuda too early...
-    # if torch.cuda.device_count() < 4:
-    #     pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
+def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
+                                           num_gpus_available, fully_sharded):
+    if num_gpus_available < 4:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
 
     llm_tp1 = vllm.LLM(MODEL_PATH,
                        enable_lora=True,
diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py
index 133e0d4514a6d..5636c96435024 100644
--- a/tests/lora/test_quant_model.py
+++ b/tests/lora/test_quant_model.py
@@ -71,10 +71,10 @@ def format_prompt_tuples(prompt):
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("tp_size", [1])
-def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
-    # Cannot use as it will initialize torch.cuda too early...
-    # if torch.cuda.device_count() < tp_size:
-    #     pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
+def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
+                          tp_size):
+    if num_gpus_available < tp_size:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
 
     llm = vllm.LLM(
         model=model.model_path,
@@ -164,11 +164,10 @@ def expect_match(output, expected_output):
 
 
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.skip("Requires multiple GPUs")
-def test_quant_model_tp_equality(tinyllama_lora_files, model):
-    # Cannot use as it will initialize torch.cuda too early...
-    # if torch.cuda.device_count() < 2:
-    #     pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
+def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
+                                 model):
+    if num_gpus_available < 2:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
 
     llm_tp1 = vllm.LLM(
         model=model.model_path,
diff --git a/tests/models/decoder_only/language/test_phimoe.py b/tests/models/decoder_only/language/test_phimoe.py
index dbdf5a1b934a6..89afbcf1c03ac 100644
--- a/tests/models/decoder_only/language/test_phimoe.py
+++ b/tests/models/decoder_only/language/test_phimoe.py
@@ -7,6 +7,7 @@
 
 from vllm.utils import is_cpu
 
+from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
 
 MODELS = [
@@ -69,20 +70,10 @@ def test_phimoe_routing_function():
         assert torch.equal(topk_ids, ground_truth[test_id]["topk_ids"])
 
 
-def get_gpu_memory():
-    try:
-        props = torch.cuda.get_device_properties(torch.cuda.current_device())
-        gpu_memory = props.total_memory / (1024**3)
-        return gpu_memory
-    except Exception:
-        return 0
-
-
 @pytest.mark.skipif(condition=is_cpu(),
                     reason="This test takes a lot time to run on CPU, "
                     "and vllm CI's disk space is not enough for this model.")
-@pytest.mark.skipif(condition=get_gpu_memory() < 100,
-                    reason="Skip this test if GPU memory is insufficient.")
+@large_gpu_test(min_gb=80)
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [64])
diff --git a/tests/models/decoder_only/vision_language/test_llava_onevision.py b/tests/models/decoder_only/vision_language/test_llava_onevision.py
index 2c4cd3fb85297..367f25f446279 100644
--- a/tests/models/decoder_only/vision_language/test_llava_onevision.py
+++ b/tests/models/decoder_only/vision_language/test_llava_onevision.py
@@ -11,6 +11,7 @@
 
 from ....conftest import (VIDEO_ASSETS, HfRunner, PromptImageInput, VllmRunner,
                           _VideoAssets)
+from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
 
 # Video test
@@ -164,9 +165,7 @@ def process(hf_inputs: BatchEncoding):
         )
 
 
-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
+@large_gpu_test(min_gb=48)
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
     "size_factors",
@@ -210,9 +209,7 @@ def test_models(hf_runner, vllm_runner, video_assets, model, size_factors,
     )
 
 
-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
+@large_gpu_test(min_gb=48)
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
     "sizes",
@@ -306,9 +303,7 @@ def process(hf_inputs: BatchEncoding):
         )
 
 
-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
+@large_gpu_test(min_gb=48)
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py
index 072bedfc01a1f..d8a98a0f84d3b 100644
--- a/tests/models/decoder_only/vision_language/test_pixtral.py
+++ b/tests/models/decoder_only/vision_language/test_pixtral.py
@@ -17,7 +17,7 @@
 from vllm.multimodal import MultiModalDataBuiltins
 from vllm.sequence import Logprob, SampleLogprobs
 
-from ....utils import VLLM_PATH
+from ....utils import VLLM_PATH, large_gpu_test
 from ...utils import check_logprobs_close
 
 if TYPE_CHECKING:
@@ -121,10 +121,7 @@ def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs:
             for tokens, text, logprobs in json_data]
 
 
-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on A100 locally but will OOM on CI machine."
-)
+@large_gpu_test(min_gb=80)
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
@@ -157,10 +154,7 @@ def test_chat(
                          name_1="output")
 
 
-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on A100 locally but will OOM on CI machine."
-)
+@large_gpu_test(min_gb=80)
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py
index ea09b758afc86..254185537e403 100644
--- a/tests/models/encoder_decoder/vision_language/test_mllama.py
+++ b/tests/models/encoder_decoder/vision_language/test_mllama.py
@@ -9,6 +9,7 @@
 
 from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
                           _ImageAssets)
+from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
 
 _LIMIT_IMAGE_PER_PROMPT = 1
@@ -227,29 +228,26 @@ def process(hf_inputs: BatchEncoding):
         )
 
 
-SIZES = [
-    # Text only
-    [],
-    # Single-size
-    [(512, 512)],
-    # Single-size, batched
-    [(512, 512), (512, 512), (512, 512)],
-    # Multi-size, batched
-    [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
-     (1024, 1024), (512, 1536), (512, 2028)],
-    # Multi-size, batched, including text only
-    [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
-     (1024, 1024), (512, 1536), (512, 2028), None],
-    # mllama has 8 possible aspect ratios, carefully set the sizes
-    # to cover all of them
-]
-
-
-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
+@large_gpu_test(min_gb=48)
 @pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("sizes", SIZES)
+@pytest.mark.parametrize(
+    "sizes",
+    [
+        # Text only
+        [],
+        # Single-size
+        [(512, 512)],
+        # Single-size, batched
+        [(512, 512), (512, 512), (512, 512)],
+        # Multi-size, batched
+        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
+         (1024, 1024), (512, 1536), (512, 2028)],
+        # Multi-size, batched, including text only
+        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
+         (1024, 1024), (512, 1536), (512, 2028), None],
+        # mllama has 8 possible aspect ratios, carefully set the sizes
+        # to cover all of them
+    ])
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
diff --git a/tests/utils.py b/tests/utils.py
index 3eff77f396e19..49bd4f236f658 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -24,8 +24,8 @@
 from vllm.entrypoints.openai.cli_args import make_arg_parser
 from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.platforms import current_platform
-from vllm.utils import (FlexibleArgumentParser, cuda_device_count_stateless,
-                        get_open_port, is_hip)
+from vllm.utils import (FlexibleArgumentParser, GB_bytes,
+                        cuda_device_count_stateless, get_open_port, is_hip)
 
 if current_platform.is_rocm():
     from amdsmi import (amdsmi_get_gpu_vram_usage,
@@ -455,6 +455,37 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
     return wrapper
 
 
+def large_gpu_test(*, min_gb: int):
+    """
+    Decorate a test to be skipped if no GPU is available or it does not have
+    sufficient memory.
+
+    Currently, the CI machine uses an L4 GPU, which has 24 GB of VRAM.
+    """
+    try:
+        if current_platform.is_cpu():
+            memory_gb = 0
+        else:
+            memory_gb = current_platform.get_device_total_memory() / GB_bytes
+    except Exception as e:
+        warnings.warn(
+            f"An error occurred when finding the available memory: {e}",
+            stacklevel=2,
+        )
+
+        memory_gb = 0
+
+    test_skipif = pytest.mark.skipif(
+        memory_gb < min_gb,
+        reason=f"Need at least {memory_gb}GB GPU memory to run the test.",
+    )
+
+    def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
+        return test_skipif(fork_new_process_for_each_test(f))
+
+    return wrapper
+
+
 def multi_gpu_test(*, num_gpus: int):
     """
     Decorate a test to be run only when multiple GPUs are available.
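For reference (an illustrative sketch, not part of the patch): a test that needs a large-memory GPU opts in via the new decorator. The test name and parameters below are hypothetical; the model tests in this patch import the helper relatively, e.g. from ....utils import large_gpu_test.

import pytest

from tests.utils import large_gpu_test


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_hypothetical_big_model(dtype: str) -> None:
    # Skipped automatically when the detected device memory is below 48 GB;
    # otherwise it runs in a forked subprocess, since large_gpu_test wraps the
    # test with fork_new_process_for_each_test.
    ...
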
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 9b348f3e17a5f..5243f59203afc 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -1,3 +1,4 @@
+import psutil
 import torch
 
 from .interface import Platform, PlatformEnum
@@ -10,6 +11,10 @@ class CpuPlatform(Platform):
     def get_device_name(cls, device_id: int = 0) -> str:
         return "cpu"
 
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        return psutil.virtual_memory().total
+
     @classmethod
     def inference_mode(cls):
         return torch.no_grad()
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index a9978d5d84d7c..fa487e2f917d8 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -59,6 +59,13 @@ def get_physical_device_name(device_id: int = 0) -> str:
     return pynvml.nvmlDeviceGetName(handle)
 
 
+@lru_cache(maxsize=8)
+@with_nvml_context
+def get_physical_device_total_memory(device_id: int = 0) -> int:
+    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+    return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)
+
+
 @with_nvml_context
 def warn_if_different_devices():
     device_ids: int = pynvml.nvmlDeviceGetCount()
@@ -107,6 +114,11 @@ def get_device_name(cls, device_id: int = 0) -> str:
         physical_device_id = device_id_to_physical_device_id(device_id)
         return get_physical_device_name(physical_device_id)
 
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        physical_device_id = device_id_to_physical_device_id(device_id)
+        return get_physical_device_total_memory(physical_device_id)
+
     @classmethod
     @with_nvml_context
     def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index f1d787f59f4a0..9ab71516b3252 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -90,6 +90,12 @@ def has_device_capability(
 
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
+        """Get the name of a device."""
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        """Get the total memory of a device in bytes."""
         raise NotImplementedError
 
     @classmethod
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index b6a19eca01745..fd8afc92b0f28 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -29,3 +29,8 @@ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
     @lru_cache(maxsize=8)
     def get_device_name(cls, device_id: int = 0) -> str:
         return torch.cuda.get_device_name(device_id)
+
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        device_props = torch.cuda.get_device_properties(device_id)
+        return device_props.total_memory
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index b30bccb103af3..a35777f91cac9 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -10,6 +10,10 @@ class TpuPlatform(Platform):
     def get_device_name(cls, device_id: int = 0) -> str:
         raise NotImplementedError
 
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        raise NotImplementedError
+
     @classmethod
     def inference_mode(cls):
         return torch.no_grad()
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index e0f98d745b5e5..d00e0dca84fff 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -8,13 +8,15 @@ class XPUPlatform(Platform):
 
     @staticmethod
     def get_device_capability(device_id: int = 0) -> DeviceCapability:
-        return DeviceCapability(major=int(
-            torch.xpu.get_device_capability(device_id)['version'].split('.')
-            [0]),
-                                minor=int(
-                                    torch.xpu.get_device_capability(device_id)
-                                    ['version'].split('.')[1]))
+        major, minor, *_ = torch.xpu.get_device_capability(
+            device_id)['version'].split('.')
+        return DeviceCapability(major=int(major), minor=int(minor))
 
     @staticmethod
     def get_device_name(device_id: int = 0) -> str:
         return torch.xpu.get_device_name(device_id)
+
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        device_props = torch.xpu.get_device_properties(device_id)
+        return device_props.total_memory
diff --git a/vllm/utils.py b/vllm/utils.py
index 20ebade5146bb..a025c3c40a434 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -119,6 +119,9 @@
 STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
 STR_INVALID_VAL: str = "INVALID"
 
+GB_bytes = 1_000_000_000
+"""The number of bytes in one gigabyte (GB)."""
+
 GiB_bytes = 1 << 30
 """The number of bytes in one gibibyte (GiB)."""