diff --git a/src/ormar_postgres_extensions/__init__.py b/src/ormar_postgres_extensions/__init__.py index f34eab7..11a0843 100644 --- a/src/ormar_postgres_extensions/__init__.py +++ b/src/ormar_postgres_extensions/__init__.py @@ -1,5 +1,6 @@ from .fields import ( # noqa: F401 ARRAY, JSONB, + MACADDR, UUID, ) diff --git a/src/ormar_postgres_extensions/fields/__init__.py b/src/ormar_postgres_extensions/fields/__init__.py index 3e4fceb..229b144 100644 --- a/src/ormar_postgres_extensions/fields/__init__.py +++ b/src/ormar_postgres_extensions/fields/__init__.py @@ -1,3 +1,4 @@ from .array import ARRAY # noqa: F401 from .jsonb import JSONB # noqa: F401 +from .macaddr import MACADDR # noqa: F401 from .uuid import UUID # noqa: F401 diff --git a/src/ormar_postgres_extensions/fields/macaddr.py b/src/ormar_postgres_extensions/fields/macaddr.py new file mode 100644 index 0000000..f6dc2d6 --- /dev/null +++ b/src/ormar_postgres_extensions/fields/macaddr.py @@ -0,0 +1,13 @@ +from typing import Any + +import ormar +from sqlalchemy.dialects import postgresql + + +class MACADDR(ormar.fields.model_fields.ModelFieldFactory, str): + _type = str + + @classmethod + def get_column_type(cls, **kwargs: Any) -> postgresql.MACADDR: + # Tell Ormar that this column should be a postgres macaddr type + return postgresql.MACADDR() diff --git a/tests/fields/test_macaddr.py b/tests/fields/test_macaddr.py new file mode 100644 index 0000000..af35ae0 --- /dev/null +++ b/tests/fields/test_macaddr.py @@ -0,0 +1,64 @@ +from typing import Optional + +import ormar +import pytest + +import ormar_postgres_extensions as ormar_pg_ext +from tests.database import ( + database, + metadata, +) + + +class MacAddrTestModel(ormar.Model): + class Meta: + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + addr: str = ormar_pg_ext.MACADDR() + + +class NullableMacAddrTestModel(ormar.Model): + class Meta: + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + addr: Optional[str] = ormar_pg_ext.MACADDR(nullable=True) + + +@pytest.mark.asyncio +async def test_create_model_with_macaddr_specified(db): + created = await MacAddrTestModel(addr="08:00:2b:01:02:03").save() + assert str(created.addr) == "08:00:2b:01:02:03" + assert isinstance(created.addr, str) + + # Confirm the model got saved to the DB by querying it back + found = await MacAddrTestModel.objects.get() + assert found.addr == created.addr + assert isinstance(found.addr, str) + + +@pytest.mark.asyncio +async def test_get_model_by_macaddr(db): + created = await MacAddrTestModel(addr="08:00:2b:01:02:03").save() + + found = await MacAddrTestModel.objects.filter(addr="08:00:2b:01:02:03").all() + assert len(found) == 1 + assert found[0] == created + + +@pytest.mark.asyncio +async def test_create_model_with_nullable_macaddr(db): + created = await NullableMacAddrTestModel().save() + assert created.addr is None + + +@pytest.mark.asyncio +async def test_get_model_with_nullable_macaddr(db): + created = await NullableMacAddrTestModel().save() + + # Ensure querying a model with a null UUID works + found = await NullableMacAddrTestModel.objects.get() + assert found == created