diff --git a/mypy_boto3_builder/structures/service_package.py b/mypy_boto3_builder/structures/service_package.py index 1288cc60..acd8311d 100644 --- a/mypy_boto3_builder/structures/service_package.py +++ b/mypy_boto3_builder/structures/service_package.py @@ -5,6 +5,7 @@ """ from collections.abc import Iterable, Iterator +from itertools import chain from typing import Literal from mypy_boto3_builder.enums.service_module_name import ServiceModuleName @@ -80,7 +81,7 @@ def extract_literals(self) -> list[TypeLiteral]: Extract literals from children. """ type_literals: set[TypeLiteral] = set() - for type_annotation in [*self.iterate_types(), *self.type_defs]: + for type_annotation in chain(self.iterate_types(), self.type_defs): if isinstance(type_annotation, TypeDefSortable): type_literals.update(type_annotation.get_children_literals()) if isinstance(type_annotation, TypeLiteral): diff --git a/tests/structures/test_service_package.py b/tests/structures/test_service_package.py index 4e8cf5be..ce5a812d 100644 --- a/tests/structures/test_service_package.py +++ b/tests/structures/test_service_package.py @@ -5,7 +5,9 @@ from mypy_boto3_builder.exceptions import StructureError from mypy_boto3_builder.package_data import Boto3StubsPackageData from mypy_boto3_builder.service_name import ServiceNameCatalog +from mypy_boto3_builder.structures.argument import Argument from mypy_boto3_builder.structures.client import Client +from mypy_boto3_builder.structures.method import Method from mypy_boto3_builder.structures.paginator import Paginator from mypy_boto3_builder.structures.service_package import ServicePackage from mypy_boto3_builder.structures.service_resource import ServiceResource @@ -19,11 +21,15 @@ class TestServicePackage: def setup_method(self) -> None: service_name = ServiceNameCatalog.s3 + client = Client("Client", service_name) + client.methods.append( + Method("method", [Argument("self", None)], TypeLiteral("NewLiteral", ["value"])) + ) self.service_package = ServicePackage( data=Boto3StubsPackageData(), service_name=service_name, version="1.2.3", - client=Client("Client", service_name), + client=client, service_resource=ServiceResource("ServiceResource", service_name), waiters=[Waiter("waiter", "waiter", "waiter", service_name)], paginators=[Paginator("Paginator", "Paginator", "paginate", service_name)], @@ -44,7 +50,9 @@ def test_client(self) -> None: _ = self.service_package.client def test_extract_literals(self) -> None: - assert self.service_package.extract_literals() == [] + literals = self.service_package.extract_literals() + assert len(literals) == 1 + assert literals[0].name == "NewLiteral" def test_extract_type_defs(self) -> None: assert self.service_package.extract_type_defs() == set() @@ -61,6 +69,13 @@ def test_get_client_required_import_records(self) -> None: "from botocore.client import BaseClient", "from botocore.errorfactory import BaseClientExceptions", "from typing import Any", + ( + "if sys.version_info >= (3, 12):" + "\n from typing import Literal" + "\nelse:" + "\n from typing_extensions import Literal" + ), + "import sys", ] def test_get_service_resource_required_import_records(self) -> None: