Merge branch 'main' into main
SunMarc authored Feb 10, 2025
2 parents fe28cd9 + 924f1c7, commit 20dd133
Showing 7 changed files with 103 additions and 26 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -52,3 +52,5 @@ markers = [
"bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests",
"generate: marks tests that use the GenerationTesterMixin"
]
log_cli = 1
log_cli_level = "WARNING"
3 changes: 1 addition & 2 deletions src/transformers/generation/configuration_utils.py
@@ -785,8 +785,7 @@ def validate(self, is_init=False):
for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"):
if getattr(self, arg_name) is not None:
logger.warning_once(
no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)),
UserWarning,
no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name))
)

# 6. check watermarking arguments
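Note on this change: `logger.warning_once` wraps the standard `logging` API, where extra positional arguments are %-format arguments rather than a warning category (that concept belongs to `warnings.warn`), so the stray `UserWarning` was never interpreted as a category. A minimal sketch using only the standard library, assuming the usual `logging`/`warnings` semantics:

```python
import logging
import warnings

logger = logging.getLogger(__name__)

# Extra positional arguments to Logger.warning are merged into the message via
# %-formatting; with no format specifier in the message, a stray second argument
# only produces a formatting error when the record is emitted.
logger.warning("Cache arguments will be ignored because caching is disabled")

# A warning category is only meaningful with the warnings module:
warnings.warn("Cache arguments will be ignored because caching is disabled", UserWarning)
```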
7 changes: 4 additions & 3 deletions src/transformers/image_processing_utils_fast.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache, partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, TypedDict, Union

@@ -497,8 +496,10 @@ def _prepare_input_images(
input_data_format=input_data_format,
device=device,
)
with ThreadPoolExecutor() as executor:
processed_images = list(executor.map(process_image_fn, images))
# todo: yoni - check if we can parallelize this efficiently
processed_images = []
for image in images:
processed_images.append(process_image_fn(image))

return processed_images

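As a side note, the sequential fallback introduced above can also be written as a list comprehension; this is only a sketch of an equivalent form, with `process_image_fn` and `images` taken from the surrounding code. The TODO left in the diff keeps the door open for reintroducing parallel processing once it is shown to help.

```python
# Equivalent sequential form of the loop introduced above (sketch only):
processed_images = [process_image_fn(image) for image in images]
```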
75 changes: 62 additions & 13 deletions src/transformers/optimization.py
@@ -393,45 +393,71 @@ def _get_wsd_scheduler_lambda(
num_warmup_steps: int,
num_stable_steps: int,
num_decay_steps: int,
num_cycles: float,
warmup_type: str,
decay_type: str,
min_lr_ratio: float,
num_cycles: float,
):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step) / float(max(1, num_warmup_steps))
if warmup_type == "linear":
factor = progress
elif warmup_type == "cosine":
factor = 0.5 * (1.0 - math.cos(math.pi * progress))
elif warmup_type == "1-sqrt":
factor = 1.0 - math.sqrt(1.0 - progress)
factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio
return max(0.0, factor)

if current_step < num_warmup_steps + num_stable_steps:
return 1.0

if current_step < num_warmup_steps + num_stable_steps + num_decay_steps:
progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))
value = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
return (1.0 - min_lr_ratio) * value + min_lr_ratio
if decay_type == "linear":
factor = 1.0 - progress
elif decay_type == "cosine":
factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
elif decay_type == "1-sqrt":
factor = 1.0 - math.sqrt(progress)
factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio
return max(0.0, factor)
return min_lr_ratio


def get_wsd_schedule(
optimizer: Optimizer,
num_warmup_steps: int,
num_stable_steps: int,
num_decay_steps: int,
num_training_steps: Optional[int] = None,
num_stable_steps: Optional[int] = None,
warmup_type: str = "linear",
decay_type: str = "cosine",
min_lr_ratio: float = 0,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Create a schedule with a learning rate that has three stages:
1. linear increase from 0 to initial lr.
2. constant lr (equal to initial lr).
3. decrease following the values of the cosine function between the initial lr set in the optimizer to
a fraction of initial lr.
1. warmup: increase from min_lr_ratio times the initial learning rate to the initial learning rate following a warmup_type.
2. stable: constant learning rate.
3. decay: decrease from the initial learning rate to min_lr_ratio times the initial learning rate following a decay_type.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_stable_steps (`int`):
The number of steps for the stable phase.
num_decay_steps (`int`):
The number of steps for the cosine annealing phase.
The number of steps for the decay phase.
num_training_steps (`int`, *optional*):
The total number of training steps. This is the sum of the warmup, stable and decay steps. If `num_stable_steps` is not provided, the stable phase will be `num_training_steps - num_warmup_steps - num_decay_steps`.
num_stable_steps (`int`, *optional*):
The number of steps for the stable phase. Please ensure that `num_warmup_steps + num_stable_steps + num_decay_steps` equals `num_training_steps`, otherwise the other steps will default to the minimum learning rate.
warmup_type (`str`, *optional*, defaults to "linear"):
The type of warmup to use. Can be 'linear', 'cosine' or '1-sqrt'.
decay_type (`str`, *optional*, defaults to "cosine"):
The type of decay to use. Can be 'linear', 'cosine' or '1-sqrt'.
min_lr_ratio (`float`, *optional*, defaults to 0):
The minimum learning rate as a ratio of the initial learning rate.
num_cycles (`float`, *optional*, defaults to 0.5):
@@ -443,11 +469,29 @@ def get_wsd_schedule(
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""

if num_training_steps is None and num_stable_steps is None:
raise ValueError("Either num_training_steps or num_stable_steps must be specified.")

if num_training_steps is not None and num_stable_steps is not None:
warnings.warn("Both num_training_steps and num_stable_steps are specified. num_stable_steps will be used.")

if warmup_type not in ["linear", "cosine", "1-sqrt"]:
raise ValueError(f"Unknown warmup type: {warmup_type}, expected 'linear', 'cosine' or '1-sqrt'")

if decay_type not in ["linear", "cosine", "1-sqrt"]:
raise ValueError(f"Unknown decay type: {decay_type}, expected 'linear', 'cosine' or '1-sqrt'")

if num_stable_steps is None:
num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps

lr_lambda = partial(
_get_wsd_scheduler_lambda,
num_warmup_steps=num_warmup_steps,
num_stable_steps=num_stable_steps,
num_decay_steps=num_decay_steps,
warmup_type=warmup_type,
decay_type=decay_type,
min_lr_ratio=min_lr_ratio,
num_cycles=num_cycles,
)
@@ -541,7 +585,12 @@ def scheduler_hook(param):
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

if name == SchedulerType.WARMUP_STABLE_DECAY:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **scheduler_specific_kwargs)
return schedule_func(
optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
**scheduler_specific_kwargs,
)

# All other schedulers require `num_training_steps`
if num_training_steps is None:
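A minimal usage sketch of the extended warmup-stable-decay schedule, assuming the merged signature above; the toy model and step counts are illustrative only:

```python
import torch
from transformers.optimization import get_wsd_schedule

model = torch.nn.Linear(8, 2)  # toy model for illustration
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# 2 warmup steps and 2 decay steps out of 10 total training steps;
# the stable phase is inferred as 10 - 2 - 2 = 6 steps.
scheduler = get_wsd_schedule(
    optimizer,
    num_warmup_steps=2,
    num_decay_steps=2,
    num_training_steps=10,
    warmup_type="linear",
    decay_type="cosine",
    min_lr_ratio=0.1,
)

for _ in range(10):
    optimizer.step()
    scheduler.step()
```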
3 changes: 2 additions & 1 deletion src/transformers/utils/logging.py
@@ -101,7 +101,8 @@ def _configure_library_root_logger() -> None:
formatter = logging.Formatter("[%(levelname)s|%(pathname)s:%(lineno)s] %(asctime)s >> %(message)s")
_default_handler.setFormatter(formatter)

library_root_logger.propagate = False
is_ci = os.getenv("CI") is not None and os.getenv("CI").upper() in {"1", "ON", "YES", "TRUE"}
library_root_logger.propagate = True if is_ci else False


def _reset_library_root_logger() -> None:
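The intent here appears to be that, on CI, library logs propagate to the root logger so that pytest's live logging (enabled via `log_cli` in pyproject.toml above) can display them. A standalone sketch of the same environment check, using hypothetical names:

```python
import logging
import os

def _running_on_ci() -> bool:
    # Mirrors the check above: the CI env var set to "1", "ON", "YES", or "TRUE".
    value = os.getenv("CI")
    return value is not None and value.upper() in {"1", "ON", "YES", "TRUE"}

library_logger = logging.getLogger("transformers")
library_logger.propagate = _running_on_ci()
```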
11 changes: 8 additions & 3 deletions tests/generation/test_streamers.py
@@ -16,6 +16,7 @@
import unittest
from queue import Empty
from threading import Thread
from unittest.mock import patch

import pytest

@@ -27,6 +28,7 @@
is_torch_available,
)
from transformers.testing_utils import CaptureStdout, require_torch, torch_device
from transformers.utils.logging import _get_library_root_logger

from ..test_modeling_common import ids_tensor

@@ -102,9 +104,12 @@ def test_text_streamer_decode_kwargs(self):
model.config.eos_token_id = -1

input_ids = torch.ones((1, 5), device=torch_device).long() * model.config.bos_token_id
with CaptureStdout() as cs:
streamer = TextStreamer(tokenizer, skip_special_tokens=True)
model.generate(input_ids, max_new_tokens=1, do_sample=False, streamer=streamer)

root = _get_library_root_logger()
with patch.object(root, "propagate", False):
with CaptureStdout() as cs:
streamer = TextStreamer(tokenizer, skip_special_tokens=True)
model.generate(input_ids, max_new_tokens=1, do_sample=False, streamer=streamer)

# The prompt contains a special token, so the streamer should not print it. As such, the output text, when
# re-tokenized, must only contain one token
28 changes: 24 additions & 4 deletions tests/optimization/test_optimization.py
@@ -153,8 +153,8 @@ def test_schedulers(self):
[0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714],
),
get_wsd_schedule: (
{"num_warmup_steps": 2, "num_stable_steps": 2, "num_decay_steps": 3, "min_lr_ratio": 0.1},
[0.0, 5.0, 10.0, 10.0, 10.0, 7.75, 3.25, 1.0, 1.0, 1.0],
{**common_kwargs, "num_decay_steps": 2, "min_lr_ratio": 0.0},
[0.0, 5.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 5.0],
),
}

@@ -183,14 +183,34 @@ def test_get_scheduler(self):
"name": "warmup_stable_decay",
"optimizer": self.optimizer,
"num_warmup_steps": 2,
"scheduler_specific_kwargs": {"num_stable_steps": 1, "num_decay_steps": 3},
"num_training_steps": 10,
"scheduler_specific_kwargs": {
"num_decay_steps": 2,
"warmup_type": "linear",
"decay_type": "linear",
},
},
{
"name": "warmup_stable_decay",
"optimizer": self.optimizer,
"num_warmup_steps": 2,
"num_training_steps": 10,
"scheduler_specific_kwargs": {
"num_decay_steps": 2,
"warmup_type": "cosine",
"decay_type": "cosine",
},
},
{
"name": "warmup_stable_decay",
"optimizer": self.optimizer,
"num_warmup_steps": 2,
"num_training_steps": 10,
"scheduler_specific_kwargs": {"num_stable_steps": 1, "num_decay_steps": 3},
"scheduler_specific_kwargs": {
"num_decay_steps": 2,
"warmup_type": "1-sqrt",
"decay_type": "1-sqrt",
},
},
{"name": "cosine", "optimizer": self.optimizer, "num_warmup_steps": 2, "num_training_steps": 10},
]
