From e5c31dd38d8679ffddafa54283a9549227b6c3a6 Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Thu, 5 Dec 2024 13:08:38 +0200 Subject: [PATCH] Preserve _keep_empty in copying and encoding. --- Lib/test/test_urlparse.py | 123 ++++++++++++++++++++++++-------------- Lib/urllib/parse.py | 44 ++++++++++++-- 2 files changed, 118 insertions(+), 49 deletions(-) diff --git a/Lib/test/test_urlparse.py b/Lib/test/test_urlparse.py index f9c583710e43a7..626f4a59e7ba82 100644 --- a/Lib/test/test_urlparse.py +++ b/Lib/test/test_urlparse.py @@ -1,9 +1,10 @@ +import copy import functools import sys import unicodedata import unittest import urllib.parse -from urllib.parse import urlparse, urlsplit, urlunparse, urlunsplit +from urllib.parse import urldefrag, urlparse, urlsplit, urlunparse, urlunsplit RFC1808_BASE = "http://a/b/c/d;p?q#f" RFC2396_BASE = "http://a/b/c/d;p?q" @@ -391,14 +392,14 @@ def checkJoin(self, base, relurl, expected, *, relroundtrip=True): self.assertEqual(urllib.parse.urljoin(baseb, relurlb), expectedb) if relroundtrip: - relurl2 = urllib.parse.urlunsplit(urllib.parse.urlsplit(relurl)) + relurl2 = urlunsplit(urlsplit(relurl)) self.assertEqual(urllib.parse.urljoin(base, relurl2), expected) - relurlb2 = urllib.parse.urlunsplit(urllib.parse.urlsplit(relurlb)) + relurlb2 = urlunsplit(urlsplit(relurlb)) self.assertEqual(urllib.parse.urljoin(baseb, relurlb2), expectedb) - relurl3 = urllib.parse.urlunsplit(urllib.parse.urlsplit(relurl, allow_none=True)) + relurl3 = urlunsplit(urlsplit(relurl, allow_none=True)) self.assertEqual(urllib.parse.urljoin(base, relurl3), expected) - relurlb3 = urllib.parse.urlunsplit(urllib.parse.urlsplit(relurlb, allow_none=True)) + relurlb3 = urlunsplit(urlsplit(relurlb, allow_none=True)) self.assertEqual(urllib.parse.urljoin(baseb, relurlb3), expectedb) def test_unparse_parse(self): @@ -458,9 +459,9 @@ def test_RFC1808(self): def test_RFC2368(self): # Issue 11467: path that starts with a number is not parsed correctly - self.assertEqual(urllib.parse.urlparse('mailto:1337@example.org'), + self.assertEqual(urlparse('mailto:1337@example.org'), ('mailto', '', '1337@example.org', '', '', '')) - self.assertEqual(urllib.parse.urlparse('mailto:1337@example.org', allow_none=True), + self.assertEqual(urlparse('mailto:1337@example.org', allow_none=True), ('mailto', None, '1337@example.org', None, None, None)) def test_RFC2396(self): @@ -1119,50 +1120,50 @@ def test_withoutscheme(self, allow_none): # RFC 1808 specifies that netloc should start with //, urlparse expects # the same, otherwise it classifies the portion of url as path. none = None if allow_none else '' - self.assertEqual(urllib.parse.urlparse("path", allow_none=allow_none), + self.assertEqual(urlparse("path", allow_none=allow_none), (none, none, 'path', none, none, none)) - self.assertEqual(urllib.parse.urlparse("//www.python.org:80", allow_none=allow_none), + self.assertEqual(urlparse("//www.python.org:80", allow_none=allow_none), (none, 'www.python.org:80', '', none, none, none)) - self.assertEqual(urllib.parse.urlparse("http://www.python.org:80", allow_none=allow_none), + self.assertEqual(urlparse("http://www.python.org:80", allow_none=allow_none), ('http', 'www.python.org:80', '', none, none, none)) # Repeat for bytes input none = None if allow_none else b'' - self.assertEqual(urllib.parse.urlparse(b"path", allow_none=allow_none), + self.assertEqual(urlparse(b"path", allow_none=allow_none), (none, none, b'path', none, none, none)) - self.assertEqual(urllib.parse.urlparse(b"//www.python.org:80", allow_none=allow_none), + self.assertEqual(urlparse(b"//www.python.org:80", allow_none=allow_none), (none, b'www.python.org:80', b'', none, none, none)) - self.assertEqual(urllib.parse.urlparse(b"http://www.python.org:80", allow_none=allow_none), + self.assertEqual(urlparse(b"http://www.python.org:80", allow_none=allow_none), (b'http', b'www.python.org:80', b'', none, none, none)) @parametrise_allow_none def test_portseparator(self, allow_none): # Issue 754016 makes changes for port separator ':' from scheme separator none = None if allow_none else '' - self.assertEqual(urllib.parse.urlparse("http:80", allow_none=allow_none), + self.assertEqual(urlparse("http:80", allow_none=allow_none), ('http', none, '80', none, none, none)) - self.assertEqual(urllib.parse.urlparse("https:80", allow_none=allow_none), + self.assertEqual(urlparse("https:80", allow_none=allow_none), ('https', none, '80', none, none, none)) - self.assertEqual(urllib.parse.urlparse("path:80", allow_none=allow_none), + self.assertEqual(urlparse("path:80", allow_none=allow_none), ('path', none, '80', none, none, none)) - self.assertEqual(urllib.parse.urlparse("http:", allow_none=allow_none), + self.assertEqual(urlparse("http:", allow_none=allow_none), ('http', none, '', none, none, none)) - self.assertEqual(urllib.parse.urlparse("https:", allow_none=allow_none), + self.assertEqual(urlparse("https:", allow_none=allow_none), ('https', none, '', none, none, none)) - self.assertEqual(urllib.parse.urlparse("http://www.python.org:80", allow_none=allow_none), + self.assertEqual(urlparse("http://www.python.org:80", allow_none=allow_none), ('http', 'www.python.org:80', '', none, none, none)) # As usual, need to check bytes input as well none = None if allow_none else b'' - self.assertEqual(urllib.parse.urlparse(b"http:80", allow_none=allow_none), + self.assertEqual(urlparse(b"http:80", allow_none=allow_none), (b'http', none, b'80', none, none, none)) - self.assertEqual(urllib.parse.urlparse(b"https:80", allow_none=allow_none), + self.assertEqual(urlparse(b"https:80", allow_none=allow_none), (b'https', none, b'80', none, none, none)) - self.assertEqual(urllib.parse.urlparse(b"path:80", allow_none=allow_none), + self.assertEqual(urlparse(b"path:80", allow_none=allow_none), (b'path', none, b'80', none, none, none)) - self.assertEqual(urllib.parse.urlparse(b"http:", allow_none=allow_none), + self.assertEqual(urlparse(b"http:", allow_none=allow_none), (b'http', none, b'', none, none, none)) - self.assertEqual(urllib.parse.urlparse(b"https:", allow_none=allow_none), + self.assertEqual(urlparse(b"https:", allow_none=allow_none), (b'https', none, b'', none, none, none)) - self.assertEqual(urllib.parse.urlparse(b"http://www.python.org:80", allow_none=allow_none), + self.assertEqual(urlparse(b"http://www.python.org:80", allow_none=allow_none), (b'http', b'www.python.org:80', b'', none, none, none)) def test_usingsys(self): @@ -1173,24 +1174,24 @@ def test_usingsys(self): def test_anyscheme(self, allow_none): # Issue 7904: s3://foo.com/stuff has netloc "foo.com". none = None if allow_none else '' - self.assertEqual(urllib.parse.urlparse("s3://foo.com/stuff", allow_none=allow_none), + self.assertEqual(urlparse("s3://foo.com/stuff", allow_none=allow_none), ('s3', 'foo.com', '/stuff', none, none, none)) - self.assertEqual(urllib.parse.urlparse("x-newscheme://foo.com/stuff", allow_none=allow_none), + self.assertEqual(urlparse("x-newscheme://foo.com/stuff", allow_none=allow_none), ('x-newscheme', 'foo.com', '/stuff', none, none, none)) - self.assertEqual(urllib.parse.urlparse("x-newscheme://foo.com/stuff?query#fragment", allow_none=allow_none), + self.assertEqual(urlparse("x-newscheme://foo.com/stuff?query#fragment", allow_none=allow_none), ('x-newscheme', 'foo.com', '/stuff', none, 'query', 'fragment')) - self.assertEqual(urllib.parse.urlparse("x-newscheme://foo.com/stuff?query", allow_none=allow_none), + self.assertEqual(urlparse("x-newscheme://foo.com/stuff?query", allow_none=allow_none), ('x-newscheme', 'foo.com', '/stuff', none, 'query', none)) # And for bytes... none = None if allow_none else b'' - self.assertEqual(urllib.parse.urlparse(b"s3://foo.com/stuff", allow_none=allow_none), + self.assertEqual(urlparse(b"s3://foo.com/stuff", allow_none=allow_none), (b's3', b'foo.com', b'/stuff', none, none, none)) - self.assertEqual(urllib.parse.urlparse(b"x-newscheme://foo.com/stuff", allow_none=allow_none), + self.assertEqual(urlparse(b"x-newscheme://foo.com/stuff", allow_none=allow_none), (b'x-newscheme', b'foo.com', b'/stuff', none, none, none)) - self.assertEqual(urllib.parse.urlparse(b"x-newscheme://foo.com/stuff?query#fragment", allow_none=allow_none), + self.assertEqual(urlparse(b"x-newscheme://foo.com/stuff?query#fragment", allow_none=allow_none), (b'x-newscheme', b'foo.com', b'/stuff', none, b'query', b'fragment')) - self.assertEqual(urllib.parse.urlparse(b"x-newscheme://foo.com/stuff?query", allow_none=allow_none), + self.assertEqual(urlparse(b"x-newscheme://foo.com/stuff?query", allow_none=allow_none), (b'x-newscheme', b'foo.com', b'/stuff', none, b'query', none)) def test_default_scheme(self): @@ -1274,12 +1275,10 @@ def test_mixed_types_rejected(self): with self.assertRaisesRegex(TypeError, "Cannot mix str"): urllib.parse.urljoin(b"http://python.org", "http://python.org") - def _check_result_type(self, str_type): - num_args = len(str_type._fields) + def _check_result_type(self, str_type, str_args): bytes_type = str_type._encoded_counterpart self.assertIs(bytes_type._decoded_counterpart, str_type) - str_args = ('',) * num_args - bytes_args = (b'',) * num_args + bytes_args = tuple(self._encode(s) for s in str_args) str_result = str_type(*str_args) bytes_result = bytes_type(*bytes_args) encoding = 'ascii' @@ -1298,16 +1297,52 @@ def _check_result_type(self, str_type): self.assertEqual(str_result.encode(encoding), bytes_result) self.assertEqual(str_result.encode(encoding, errors), bytes_args) self.assertEqual(str_result.encode(encoding, errors), bytes_result) + for result in str_result, bytes_result: + self.assertEqual(copy.copy(result), result) + self.assertEqual(copy.deepcopy(result), result) + self.assertEqual(copy.replace(result), result) + self.assertEqual(result._replace(), result) def test_result_pairs(self): # Check encoding and decoding between result pairs - result_types = [ - urllib.parse.DefragResult, - urllib.parse.SplitResult, - urllib.parse.ParseResult, - ] - for result_type in result_types: - self._check_result_type(result_type) + self._check_result_type(urllib.parse.DefragResult, ('', '')) + self._check_result_type(urllib.parse.DefragResult, ('', None)) + self._check_result_type(urllib.parse.SplitResult, ('', '', '', '', '')) + self._check_result_type(urllib.parse.SplitResult, (None, None, '', None, None)) + self._check_result_type(urllib.parse.ParseResult, ('', '', '', '', '', '')) + self._check_result_type(urllib.parse.ParseResult, (None, None, '', None, None, None)) + + def test_result_encoding_decoding(self): + def check(str_result, bytes_result): + self.assertEqual(str_result.encode(), bytes_result) + self.assertEqual(str_result.encode().geturl(), bytes_result.geturl()) + self.assertEqual(bytes_result.decode(), str_result) + self.assertEqual(bytes_result.decode().geturl(), str_result.geturl()) + + url = 'http://example.com/?#' + burl = url.encode() + for func in urldefrag, urlsplit, urlparse: + check(func(url, allow_none=True), func(burl, allow_none=True)) + check(func(url), func(burl)) + + def test_result_copying(self): + def check(result): + self.assertEqual(copy.copy(result), result) + self.assertEqual(copy.copy(result).geturl(), result.geturl()) + self.assertEqual(copy.deepcopy(result), result) + self.assertEqual(copy.deepcopy(result).geturl(), result.geturl()) + self.assertEqual(copy.replace(result), result) + self.assertEqual(copy.replace(result).geturl(), result.geturl()) + self.assertEqual(result._replace(), result) + self.assertEqual(result._replace().geturl(), result.geturl()) + + url = 'http://example.com/?#' + burl = url.encode() + for func in urldefrag, urlsplit, urlparse: + check(func(url)) + check(func(url, allow_none=True)) + check(func(burl)) + check(func(burl, allow_none=True)) def test_parse_qs_encoding(self): result = urllib.parse.parse_qs("key=\u0141%E9", encoding="latin-1") diff --git a/Lib/urllib/parse.py b/Lib/urllib/parse.py index 8d2a05bd134135..86db54e9b14fe7 100644 --- a/Lib/urllib/parse.py +++ b/Lib/urllib/parse.py @@ -146,9 +146,14 @@ class _ResultMixinStr(object): __slots__ = () def encode(self, encoding='ascii', errors='strict'): - return self._encoded_counterpart(*(x.encode(encoding, errors) + result = self._encoded_counterpart(*(x.encode(encoding, errors) if x is not None else None for x in self)) + try: + result._keep_empty = self._keep_empty + except AttributeError: + pass + return result class _ResultMixinBytes(object): @@ -156,9 +161,14 @@ class _ResultMixinBytes(object): __slots__ = () def decode(self, encoding='ascii', errors='strict'): - return self._decoded_counterpart(*(x.decode(encoding, errors) + result = self._decoded_counterpart(*(x.decode(encoding, errors) if x is not None else None for x in self)) + try: + result._keep_empty = self._keep_empty + except AttributeError: + pass + return result class _NetlocResultMixinBase(object): @@ -270,7 +280,31 @@ def _hostinfo(self): _UNSPECIFIED = ['not specified'] _ALLOW_NONE_DEFAULT = False -class _DefragResultBase(namedtuple('_DefragResultBase', 'url fragment')): +class _ResultBase: + def __replace__(self, /, **kwargs): + result = super().__replace__(**kwargs) + try: + result._keep_empty = self._keep_empty + except AttributeError: + pass + return result + + def _replace(self, /, **kwargs): + result = super()._replace(**kwargs) + try: + result._keep_empty = self._keep_empty + except AttributeError: + pass + return result + + def __copy__(self): + return self + + def __deepcopy__(self, memo): + return self + + +class _DefragResultBase(_ResultBase, namedtuple('_DefragResultBase', 'url fragment')): def geturl(self): if self.fragment or (self.fragment is not None and getattr(self, '_keep_empty', _ALLOW_NONE_DEFAULT)): @@ -278,12 +312,12 @@ def geturl(self): else: return self.url -class _SplitResultBase(namedtuple( +class _SplitResultBase(_ResultBase, namedtuple( '_SplitResultBase', 'scheme netloc path query fragment')): def geturl(self): return urlunsplit(self) -class _ParseResultBase(namedtuple( +class _ParseResultBase(_ResultBase, namedtuple( '_ParseResultBase', 'scheme netloc path params query fragment')): def geturl(self): return urlunparse(self)