diff --git a/pathvalidate/_common.py b/pathvalidate/_common.py index 950640e..066e5b1 100644 --- a/pathvalidate/_common.py +++ b/pathvalidate/_common.py @@ -138,3 +138,10 @@ def normalize_platform(name: Optional[PlatformType]) -> Platform: def findall_to_str(match: List[Any]) -> str: return ", ".join([repr(text) for text in match]) + + +def truncate_str(text: str, encoding: str, max_bytes: int) -> str: + str_bytes = text.encode(encoding) + str_bytes = str_bytes[:max_bytes] + # last char might be malformed, ignore it + return str_bytes.decode(encoding, "ignore") diff --git a/pathvalidate/_filename.py b/pathvalidate/_filename.py index bd3c421..1b2168f 100644 --- a/pathvalidate/_filename.py +++ b/pathvalidate/_filename.py @@ -11,7 +11,7 @@ from typing import Optional, Pattern, Sequence, Tuple from ._base import AbstractSanitizer, AbstractValidator, BaseFile, BaseValidator -from ._common import findall_to_str, to_str, validate_pathtype +from ._common import findall_to_str, to_str, truncate_str, validate_pathtype from ._const import DEFAULT_MIN_LEN, INVALID_CHAR_ERR_MSG_TMPL, Platform from ._types import PathType, PlatformType from .error import ErrorAttrKey, ErrorReason, InvalidCharError, ValidationError @@ -75,7 +75,7 @@ def sanitize(self, value: PathType, replacement_text: str = "") -> PathType: raise sanitized_filename = self._sanitize_regexp.sub(replacement_text, str(value)) - sanitized_filename = sanitized_filename[: self.max_len] + sanitized_filename = truncate_str(sanitized_filename, self._fs_encoding, self.max_len) try: self._validator.validate(sanitized_filename) diff --git a/test/test_filename.py b/test/test_filename.py index 1383e37..176b647 100644 --- a/test/test_filename.py +++ b/test/test_filename.py @@ -756,3 +756,20 @@ def test_exception_type(self, value, expected): with pytest.raises(expected): sanitize_filename(value) assert not is_valid_filename(value) + + @pytest.mark.parametrize( + ["value", "platform", "fs_encoding", "max_len", "expected"], + [ + ["あ" * 85, "universal", "utf-8", 255, "あ" * 85], + ["あ" * 86, "universal", "utf-8", 255, "あ" * 85], + ["あ" * 126, "universal", "utf-16", 255, "あ" * 126], + ["あ" * 127, "universal", "utf-16", 255, "あ" * 126], + ], + ) + def test_max_len_fs_encoding(self, value, platform, fs_encoding, max_len, expected): + kwargs = { + "platform": platform, + "max_len": max_len, + "fs_encoding": fs_encoding, + } + assert sanitize_filename(value, **kwargs) == expected