From 86bf8bf94abcc5045c2b5584ac8123c6d8da64bb Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Fri, 5 Jul 2024 16:03:00 +0200 Subject: [PATCH 1/3] Add `strtobool` custom implementation from `distutils` --- trl/env_utils.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 trl/env_utils.py diff --git a/trl/env_utils.py b/trl/env_utils.py new file mode 100644 index 0000000000..455e0dbfca --- /dev/null +++ b/trl/env_utils.py @@ -0,0 +1,36 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Function `strtobool` copied and adapted from `distutils` (as deprecated +# in Python 3.10). +# Reference: https://github.com/python/cpython/blob/48f9d3e3faec5faaa4f7c9849fecd27eae4da213/Lib/distutils/util.py#L308-L321 + + +def strtobool(val: str) -> bool: + """Convert a string representation of truth to True or False booleans. + + True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values + are 'n', 'no', 'f', 'false', 'off', and '0'. + + Raises: + ValueError: if 'val' is anything else. + """ + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1"): + return True + if val in ("n", "no", "f", "false", "off", "0"): + return False + raise ValueError( + f"Invalid truth value, it should be a string but {val} was provided instead." 
+ ) From 806ba2fb2e9ed4cea4c45a3903f5937dcb5a6c3a Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Fri, 5 Jul 2024 16:03:23 +0200 Subject: [PATCH 2/3] Fix `TRL_USE_RICH` handling via `strtobool` --- examples/scripts/dpo.py | 5 +++-- examples/scripts/sft.py | 5 +++-- examples/scripts/vsft_llava.py | 5 +++-- trl/commands/cli.py | 5 +++-- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py index 62df16f06d..4c50612b28 100644 --- a/examples/scripts/dpo.py +++ b/examples/scripts/dpo.py @@ -55,9 +55,10 @@ import os from contextlib import nullcontext -TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False) - from trl.commands.cli_utils import DPOScriptArguments, init_zero_verbose, TrlParser +from trl.env_utils import strtobool + +TRL_USE_RICH = strtobool(os.getenv("TRL_USE_RICH", "0")) if TRL_USE_RICH: init_zero_verbose() diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py index 80ad65d96b..d98b1c671f 100644 --- a/examples/scripts/sft.py +++ b/examples/scripts/sft.py @@ -49,9 +49,10 @@ import os from contextlib import nullcontext -TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False) - from trl.commands.cli_utils import init_zero_verbose, SFTScriptArguments, TrlParser +from trl.env_utils import strtobool + +TRL_USE_RICH = strtobool(os.getenv("TRL_USE_RICH", "0")) if TRL_USE_RICH: init_zero_verbose() diff --git a/examples/scripts/vsft_llava.py b/examples/scripts/vsft_llava.py index 85cb98d5f3..32e9e0b804 100644 --- a/examples/scripts/vsft_llava.py +++ b/examples/scripts/vsft_llava.py @@ -68,9 +68,10 @@ import os from contextlib import nullcontext -TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False) - from trl.commands.cli_utils import init_zero_verbose, SFTScriptArguments, TrlParser +from trl.env_utils import strtobool + +TRL_USE_RICH = strtobool(os.getenv("TRL_USE_RICH", "0")) if TRL_USE_RICH: init_zero_verbose() diff --git a/trl/commands/cli.py 
b/trl/commands/cli.py index 46b761473a..f5695c233d 100644 --- a/trl/commands/cli.py +++ b/trl/commands/cli.py @@ -41,8 +41,9 @@ def main(): trl_examples_dir = os.path.dirname(__file__) - # Force-use rich - os.environ["TRL_USE_RICH"] = "1" + # Force-use rich if the `TRL_USE_RICH` env var is not set + if "TRL_USE_RICH" not in os.environ: + os.environ["TRL_USE_RICH"] = "1" if command_name == "chat": command = f""" From 5d8609b26bf7642dddab6d77daa937d94d5e251f Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Fri, 5 Jul 2024 16:05:39 +0200 Subject: [PATCH 3/3] Run `make precommit` --- trl/env_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/trl/env_utils.py b/trl/env_utils.py index 455e0dbfca..64e98199e0 100644 --- a/trl/env_utils.py +++ b/trl/env_utils.py @@ -31,6 +31,4 @@ def strtobool(val: str) -> bool: return True if val in ("n", "no", "f", "false", "off", "0"): return False - raise ValueError( - f"Invalid truth value, it should be a string but {val} was provided instead." - ) + raise ValueError(f"Invalid truth value, it should be a string but {val} was provided instead.")