From 3897f2caf81470b04dce8343bfc11e4ef851d31a Mon Sep 17 00:00:00 2001
From: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Date: Mon, 10 Feb 2025 13:36:20 +0100
Subject: [PATCH] Enable pytest live log and show warning logs on GitHub
 Actions CI runs (#35912)

* fix

* remove

* fix

---------

Co-authored-by: ydshieh
---
 pyproject.toml                                      |  2 ++
 src/transformers/generation/configuration_utils.py |  3 +--
 src/transformers/utils/logging.py                   |  3 ++-
 tests/generation/test_streamers.py                  | 11 ++++++++---
 4 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index bf78e0174394..79a6d9e70ae8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,3 +52,5 @@ markers = [
     "bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests",
     "generate: marks tests that use the GenerationTesterMixin"
 ]
+log_cli = 1
+log_cli_level = "WARNING"
\ No newline at end of file
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index a0e96c31cb59..3b01607fd048 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -785,8 +785,7 @@ def validate(self, is_init=False):
             for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"):
                 if getattr(self, arg_name) is not None:
                     logger.warning_once(
-                        no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)),
-                        UserWarning,
+                        no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name))
                     )
 
         # 6. check watermarking arguments
diff --git a/src/transformers/utils/logging.py b/src/transformers/utils/logging.py
index a304e9d29f46..67f70b96eddc 100644
--- a/src/transformers/utils/logging.py
+++ b/src/transformers/utils/logging.py
@@ -101,7 +101,8 @@ def _configure_library_root_logger() -> None:
             formatter = logging.Formatter("[%(levelname)s|%(pathname)s:%(lineno)s] %(asctime)s >> %(message)s")
             _default_handler.setFormatter(formatter)
 
-        library_root_logger.propagate = False
+        is_ci = os.getenv("CI") is not None and os.getenv("CI").upper() in {"1", "ON", "YES", "TRUE"}
+        library_root_logger.propagate = True if is_ci else False
 
 
 def _reset_library_root_logger() -> None:
diff --git a/tests/generation/test_streamers.py b/tests/generation/test_streamers.py
index be8c37334d02..2b17b6a7c43a 100644
--- a/tests/generation/test_streamers.py
+++ b/tests/generation/test_streamers.py
@@ -16,6 +16,7 @@
 import unittest
 from queue import Empty
 from threading import Thread
+from unittest.mock import patch
 
 import pytest
 
@@ -27,6 +28,7 @@
     is_torch_available,
 )
 from transformers.testing_utils import CaptureStdout, require_torch, torch_device
+from transformers.utils.logging import _get_library_root_logger
 
 from ..test_modeling_common import ids_tensor
 
@@ -102,9 +104,12 @@ def test_text_streamer_decode_kwargs(self):
         model.config.eos_token_id = -1
 
         input_ids = torch.ones((1, 5), device=torch_device).long() * model.config.bos_token_id
-        with CaptureStdout() as cs:
-            streamer = TextStreamer(tokenizer, skip_special_tokens=True)
-            model.generate(input_ids, max_new_tokens=1, do_sample=False, streamer=streamer)
+
+        root = _get_library_root_logger()
+        with patch.object(root, "propagate", False):
+            with CaptureStdout() as cs:
+                streamer = TextStreamer(tokenizer, skip_special_tokens=True)
+                model.generate(input_ids, max_new_tokens=1, do_sample=False, streamer=streamer)
 
         # The prompt contains a special token, so the streamer should not print it. As such, the output text, when
         # re-tokenized, must only contain one token