diff --git a/python/dev/Dockerfile b/python/dev/Dockerfile index 21f732e77445..a4099d34946e 100644 --- a/python/dev/Dockerfile +++ b/python/dev/Dockerfile @@ -40,6 +40,7 @@ ENV SPARK_VERSION=3.4.1 ENV ICEBERG_SPARK_RUNTIME_VERSION=3.4_2.12 ENV ICEBERG_VERSION=1.3.1 ENV AWS_SDK_VERSION=2.20.18 +ENV PYICEBERG_VERSION=0.4.0 RUN curl --retry 3 -s -C - https://dlcdn.apache.org/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-hadoop3.tgz -o spark-${SPARK_VERSION}-bin-hadoop3.tgz \ && tar xzf spark-${SPARK_VERSION}-bin-hadoop3.tgz --directory /opt/spark --strip-components 1 \ @@ -65,6 +66,8 @@ RUN chmod u+x /opt/spark/sbin/* && \ RUN pip3 install -q ipython +RUN pip3 install "pyiceberg[s3fs]==${PYICEBERG_VERSION}" + COPY entrypoint.sh . COPY provision.py . diff --git a/python/dev/provision.py b/python/dev/provision.py index f62687b746af..37c5fec97339 100644 --- a/python/dev/provision.py +++ b/python/dev/provision.py @@ -18,6 +18,10 @@ from pyspark.sql import SparkSession from pyspark.sql.functions import current_date, date_add, expr +from pyiceberg.catalog import load_catalog +from pyiceberg.schema import Schema +from pyiceberg.types import FixedType, NestedField, UUIDType + spark = SparkSession.builder.getOrCreate() spark.sql( @@ -26,6 +30,35 @@ """ ) +schema = Schema( + NestedField(field_id=1, name="uuid_col", field_type=UUIDType(), required=False), + NestedField(field_id=2, name="fixed_col", field_type=FixedType(25), required=False), +) + +catalog = load_catalog( + "local", + **{ + "type": "rest", + "uri": "http://rest:8181", + "s3.endpoint": "http://minio:9000", + "s3.access-key-id": "admin", + "s3.secret-access-key": "password", + }, +) + +catalog.create_table(identifier="default.test_uuid_and_fixed_unpartitioned", schema=schema) + +spark.sql( + """ + INSERT INTO default.test_uuid_and_fixed_unpartitioned VALUES + ('102cb62f-e6f8-4eb0-9973-d9b012ff0967', CAST('1234567890123456789012345' AS BINARY)), + ('ec33e4b2-a834-4cc3-8c4a-a1d3bfc2f226', CAST('1231231231231231231231231' AS BINARY)), + ('639cccce-c9d2-494a-a78c-278ab234f024', CAST('12345678901234567ass12345' AS BINARY)), + ('c1b0d8e0-0b0e-4b1e-9b0a-0e0b0d0c0a0b', CAST('asdasasdads12312312312111' AS BINARY)), + ('923dae77-83d6-47cd-b4b0-d383e64ee57e', CAST('qweeqwwqq1231231231231111' AS BINARY)); + """ +) + spark.sql( """ CREATE OR REPLACE TABLE default.test_null_nan diff --git a/python/pyiceberg/conversions.py b/python/pyiceberg/conversions.py index 8f155fce3d95..0b3f36fba9d6 100644 --- a/python/pyiceberg/conversions.py +++ b/python/pyiceberg/conversions.py @@ -65,7 +65,6 @@ _LONG_STRUCT = Struct("QQ") def handle_none(func: Callable) -> Callable: # type: ignore @@ -228,8 +227,10 @@ def _(_: StringType, value: str) -> bytes: @to_bytes.register(UUIDType) -def _(_: UUIDType, value: uuid.UUID) -> bytes: - return _UUID_STRUCT.pack((value.int >> 64) & 0xFFFFFFFFFFFFFFFF, value.int & 0xFFFFFFFFFFFFFFFF) +def _(_: UUIDType, value: Union[uuid.UUID, bytes]) -> bytes: + if isinstance(value, bytes): + return value + return value.bytes @to_bytes.register(BinaryType) @@ -310,14 +311,9 @@ def _(_: StringType, b: bytes) -> str: return bytes(b).decode("utf-8") -@from_bytes.register(UUIDType) -def _(_: UUIDType, b: bytes) -> uuid.UUID: - unpacked_bytes = _UUID_STRUCT.unpack(b) - return uuid.UUID(int=unpacked_bytes[0] << 64 | unpacked_bytes[1]) - - @from_bytes.register(BinaryType) @from_bytes.register(FixedType) +@from_bytes.register(UUIDType) def _(_: PrimitiveType, b: bytes) -> bytes: return b diff --git a/python/pyiceberg/expressions/literals.py b/python/pyiceberg/expressions/literals.py index f89d0c8331b6..b24f3932d53b 100644 --- a/python/pyiceberg/expressions/literals.py +++ b/python/pyiceberg/expressions/literals.py @@ -57,6 +57,8 @@ from pyiceberg.utils.decimal import decimal_to_unscaled, unscaled_to_decimal from pyiceberg.utils.singleton import Singleton +UUID_BYTES_LENGTH = 16 + class Literal(Generic[L], ABC): """Literal which has a value and can be converted between types.""" @@ -139,7 +141,7 @@ def literal(value: L) -> Literal[L]: elif isinstance(value, str): return StringLiteral(value) elif isinstance(value, UUID): - return UUIDLiteral(value) + return UUIDLiteral(value.bytes) # type: ignore elif isinstance(value, bytes): return BinaryLiteral(value) elif isinstance(value, Decimal): @@ -571,8 +573,8 @@ def _(self, _: TimestamptzType) -> Literal[int]: return TimestampLiteral(timestamptz_to_micros(self.value)) @to.register(UUIDType) - def _(self, _: UUIDType) -> Literal[UUID]: - return UUIDLiteral(UUID(self.value)) + def _(self, _: UUIDType) -> Literal[bytes]: + return UUIDLiteral(UUID(self.value).bytes) @to.register(DecimalType) def _(self, type_var: DecimalType) -> Literal[Decimal]: @@ -596,16 +598,16 @@ def __repr__(self) -> str: return f"literal({repr(self.value)})" -class UUIDLiteral(Literal[UUID]): - def __init__(self, value: UUID) -> None: - super().__init__(value, UUID) +class UUIDLiteral(Literal[bytes]): + def __init__(self, value: bytes) -> None: + super().__init__(value, bytes) @singledispatchmethod def to(self, type_var: IcebergType) -> Literal: # type: ignore raise TypeError(f"Cannot convert UUIDLiteral into {type_var}") @to.register(UUIDType) - def _(self, _: UUIDType) -> Literal[UUID]: + def _(self, _: UUIDType) -> Literal[bytes]: return self @@ -630,6 +632,15 @@ def _(self, type_var: FixedType) -> Literal[bytes]: def _(self, _: BinaryType) -> Literal[bytes]: return BinaryLiteral(self.value) + @to.register(UUIDType) + def _(self, type_var: UUIDType) -> Literal[bytes]: + if len(self.value) == UUID_BYTES_LENGTH: + return UUIDLiteral(self.value) + else: + raise TypeError( + f"Could not convert {self.value!r} into a {type_var}, lengths differ {len(self.value)} <> {UUID_BYTES_LENGTH}" + ) + class BinaryLiteral(Literal[bytes]): def __init__(self, value: bytes) -> None: @@ -651,3 +662,12 @@ def _(self, type_var: FixedType) -> Literal[bytes]: raise TypeError( f"Cannot convert BinaryLiteral into {type_var}, different length: {len(type_var)} <> {len(self.value)}" ) + + @to.register(UUIDType) + def _(self, type_var: UUIDType) -> Literal[bytes]: + if len(self.value) == UUID_BYTES_LENGTH: + return UUIDLiteral(self.value) + else: + raise TypeError( + f"Cannot convert BinaryLiteral into {type_var}, different length: {UUID_BYTES_LENGTH} <> {len(self.value)}" + ) diff --git a/python/pyiceberg/io/pyarrow.py b/python/pyiceberg/io/pyarrow.py index 425507200321..fba16f99929d 100644 --- a/python/pyiceberg/io/pyarrow.py +++ b/python/pyiceberg/io/pyarrow.py @@ -451,7 +451,7 @@ def visit_binary(self, _: BinaryType) -> pa.DataType: def _convert_scalar(value: Any, iceberg_type: IcebergType) -> pa.scalar: if not isinstance(iceberg_type, PrimitiveType): raise ValueError(f"Expected primitive type, got: {iceberg_type}") - return pa.scalar(value).cast(schema_to_pyarrow(iceberg_type)) + return pa.scalar(value=value, type=schema_to_pyarrow(iceberg_type)) class _ConvertToArrowExpression(BoundBooleanExpressionVisitor[pc.Expression]): diff --git a/python/pyiceberg/transforms.py b/python/pyiceberg/transforms.py index 4b67f6687375..3e90c911d139 100644 --- a/python/pyiceberg/transforms.py +++ b/python/pyiceberg/transforms.py @@ -28,6 +28,7 @@ ) from typing import Literal as LiteralType from typing import Optional, TypeVar +from uuid import UUID import mmh3 from pydantic import Field, PositiveInt, PrivateAttr @@ -269,13 +270,9 @@ def hash_func(v: Any) -> int: elif source_type == UUIDType: def hash_func(v: Any) -> int: - return mmh3.hash( - struct.pack( - ">QQ", - (v.int >> 64) & 0xFFFFFFFFFFFFFFFF, - v.int & 0xFFFFFFFFFFFFFFFF, - ) - ) + if isinstance(v, UUID): + return mmh3.hash(v.bytes) + return mmh3.hash(v) else: raise ValueError(f"Unknown type {source}") diff --git a/python/tests/expressions/test_literals.py b/python/tests/expressions/test_literals.py index 16aee4dbc35c..309bd28c4cc9 100644 --- a/python/tests/expressions/test_literals.py +++ b/python/tests/expressions/test_literals.py @@ -373,7 +373,7 @@ def test_string_to_uuid_literal() -> None: uuid_str = literal(str(expected)) uuid_lit = uuid_str.to(UUIDType()) - assert expected == uuid_lit.value + assert expected.bytes == uuid_lit.value def test_string_to_decimal_literal() -> None: @@ -503,6 +503,22 @@ def test_binary_to_smaller_fixed_none() -> None: assert "Cannot convert BinaryLiteral into fixed[2], different length: 2 <> 3" in str(e.value) +def test_binary_to_uuid() -> None: + test_uuid = uuid.uuid4() + lit = literal(test_uuid.bytes) + uuid_lit = lit.to(UUIDType()) + assert uuid_lit is not None + assert lit.value == uuid_lit.value + assert uuid_lit.value == test_uuid.bytes + + +def test_incompatible_binary_to_uuid() -> None: + lit = literal(bytes([0x00, 0x01, 0x02])) + with pytest.raises(TypeError) as e: + _ = lit.to(UUIDType()) + assert "Cannot convert BinaryLiteral into uuid, different length: 16 <> 3" in str(e.value) + + def test_fixed_to_binary() -> None: lit = literal(bytes([0x00, 0x01, 0x02])).to(FixedType(3)) binary_lit = lit.to(BinaryType()) @@ -517,6 +533,22 @@ def test_fixed_to_smaller_fixed_none() -> None: assert "Could not convert b'\\x00\\x01\\x02' into a fixed[2]" in str(e.value) +def test_fixed_to_uuid() -> None: + test_uuid = uuid.uuid4() + lit = literal(test_uuid.bytes).to(FixedType(16)) + uuid_lit = lit.to(UUIDType()) + assert uuid_lit is not None + assert lit.value == uuid_lit.value + assert uuid_lit.value == test_uuid.bytes + + +def test_incompatible_fixed_to_uuid() -> None: + lit = literal(bytes([0x00, 0x01, 0x02])).to(FixedType(3)) + with pytest.raises(TypeError) as e: + _ = lit.to(UUIDType()) + assert "Cannot convert BinaryLiteral into uuid, different length: 16 <> 3" in str(e.value) + + def test_above_max_float() -> None: a = FloatAboveMax() # singleton @@ -843,6 +875,13 @@ def test_decimal_literal_dencrement() -> None: assert dec.decrement().value.as_tuple() == Decimal("10.122").as_tuple() +def test_uuid_literal_initialization() -> None: + test_uuid = uuid.UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7") + uuid_literal = literal(test_uuid) + assert isinstance(uuid_literal, Literal) + assert test_uuid.bytes == uuid_literal.value + + # __ __ ___ # | \/ |_ _| _ \_ _ # | |\/| | || | _/ || | @@ -853,7 +892,6 @@ def test_decimal_literal_dencrement() -> None: assert_type(literal(True), Literal[bool]) assert_type(literal(123), Literal[int]) assert_type(literal(123.4), Literal[float]) -assert_type(literal(uuid.UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7")), Literal[uuid.UUID]) assert_type(literal(bytes([0x01, 0x02, 0x03])), Literal[bytes]) assert_type(literal(Decimal("19.25")), Literal[Decimal]) assert_type({literal(1), literal(2), literal(3)}, Set[Literal[int]]) diff --git a/python/tests/test_conversions.py b/python/tests/test_conversions.py index 429de6e01130..3b3e519579fa 100644 --- a/python/tests/test_conversions.py +++ b/python/tests/test_conversions.py @@ -270,9 +270,9 @@ def test_partition_to_py_raise_on_incorrect_precision_or_scale( ( UUIDType(), b"\xf7\x9c>\tg|K\xbd\xa4y?4\x9c\xb7\x85\xe7", - uuid.UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7"), + b"\xf7\x9c>\tg|K\xbd\xa4y?4\x9c\xb7\x85\xe7", ), - (UUIDType(), b"\xf7\x9c>\tg|K\xbd\xa4y?4\x9c\xb7\x85\xe7", uuid.UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7")), + (UUIDType(), b"\xf7\x9c>\tg|K\xbd\xa4y?4\x9c\xb7\x85\xe7", b"\xf7\x9c>\tg|K\xbd\xa4y?4\x9c\xb7\x85\xe7"), (FixedType(3), b"foo", b"foo"), (BinaryType(), b"foo", b"foo"), (DecimalType(5, 2), b"\x30\x39", Decimal("123.45")), @@ -308,9 +308,9 @@ def test_from_bytes(primitive_type: PrimitiveType, b: bytes, result: Any) -> Non ( UUIDType(), b"\xf7\x9c>\tg|K\xbd\xa4y?4\x9c\xb7\x85\xe7", - uuid.UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7"), + b"\xf7\x9c>\tg|K\xbd\xa4y?4\x9c\xb7\x85\xe7", ), - (UUIDType(), b"\xf7\x9c>\tg|K\xbd\xa4y?4\x9c\xb7\x85\xe7", uuid.UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7")), + (UUIDType(), b"\xf7\x9c>\tg|K\xbd\xa4y?4\x9c\xb7\x85\xe7", b"\xf7\x9c>\tg|K\xbd\xa4y?4\x9c\xb7\x85\xe7"), (FixedType(3), b"foo", b"foo"), (BinaryType(), b"foo", b"foo"), (DecimalType(5, 2), b"\x30\x39", Decimal("123.45")), @@ -341,6 +341,22 @@ def test_round_trip_conversion(primitive_type: PrimitiveType, b: bytes, result: assert bytes_from_value == b +@pytest.mark.parametrize( + "primitive_type, v, result", + [ + ( + UUIDType(), + uuid.UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7"), + b"\xf7\x9c>\tg|K\xbd\xa4y?4\x9c\xb7\x85\xe7", + ), + (UUIDType(), uuid.UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7"), b"\xf7\x9c>\tg|K\xbd\xa4y?4\x9c\xb7\x85\xe7"), + ], +) +def test_uuid_to_bytes(primitive_type: PrimitiveType, v: Any, result: bytes) -> None: + bytes_from_value = conversions.to_bytes(primitive_type, v) + assert bytes_from_value == result + + @pytest.mark.parametrize( "primitive_type, b, result", [ diff --git a/python/tests/test_integration.py b/python/tests/test_integration.py index 37ba5b9048be..9a3e044e21d4 100644 --- a/python/tests/test_integration.py +++ b/python/tests/test_integration.py @@ -17,6 +17,7 @@ # pylint:disable=redefined-outer-name import math +import uuid from urllib.parse import urlparse import pyarrow.parquet as pq @@ -27,9 +28,11 @@ from pyiceberg.exceptions import NoSuchTableError from pyiceberg.expressions import ( And, + EqualTo, GreaterThanOrEqual, IsNaN, LessThan, + NotEqualTo, NotNaN, ) from pyiceberg.io.pyarrow import pyarrow_to_schema @@ -315,3 +318,37 @@ def test_partitioned_tables(catalog: Catalog) -> None: table = catalog.load_table(f"default.{table_name}") arrow_table = table.scan(selected_fields=("number",), row_filter=predicate).to_arrow() assert set(arrow_table["number"].to_pylist()) == {5, 6, 7, 8, 9, 10, 11, 12}, f"Table {table_name}, predicate {predicate}" + + +@pytest.mark.integration +def test_unpartitioned_uuid_table(catalog: Catalog) -> None: + unpartitioned_uuid = catalog.load_table("default.test_uuid_and_fixed_unpartitioned") + arrow_table_eq = unpartitioned_uuid.scan(row_filter="uuid_col == '102cb62f-e6f8-4eb0-9973-d9b012ff0967'").to_arrow() + assert arrow_table_eq["uuid_col"].to_pylist() == [uuid.UUID("102cb62f-e6f8-4eb0-9973-d9b012ff0967").bytes] + + arrow_table_neq = unpartitioned_uuid.scan( + row_filter="uuid_col != '102cb62f-e6f8-4eb0-9973-d9b012ff0967' and uuid_col != '639cccce-c9d2-494a-a78c-278ab234f024'" + ).to_arrow() + assert arrow_table_neq["uuid_col"].to_pylist() == [ + uuid.UUID("ec33e4b2-a834-4cc3-8c4a-a1d3bfc2f226").bytes, + uuid.UUID("c1b0d8e0-0b0e-4b1e-9b0a-0e0b0d0c0a0b").bytes, + uuid.UUID("923dae77-83d6-47cd-b4b0-d383e64ee57e").bytes, + ] + + +@pytest.mark.integration +def test_unpartitioned_fixed_table(catalog: Catalog) -> None: + fixed_table = catalog.load_table("default.test_uuid_and_fixed_unpartitioned") + arrow_table_eq = fixed_table.scan(row_filter=EqualTo("fixed_col", b"1234567890123456789012345")).to_arrow() + assert arrow_table_eq["fixed_col"].to_pylist() == [b"1234567890123456789012345"] + + arrow_table_neq = fixed_table.scan( + row_filter=And( + NotEqualTo("fixed_col", b"1234567890123456789012345"), NotEqualTo("uuid_col", "c1b0d8e0-0b0e-4b1e-9b0a-0e0b0d0c0a0b") + ) + ).to_arrow() + assert arrow_table_neq["fixed_col"].to_pylist() == [ + b"1231231231231231231231231", + b"12345678901234567ass12345", + b"qweeqwwqq1231231231231111", + ] diff --git a/python/tests/test_schema.py b/python/tests/test_schema.py index 76dddb6486d3..d3400b62669d 100644 --- a/python/tests/test_schema.py +++ b/python/tests/test_schema.py @@ -21,20 +21,56 @@ import pytest from pyiceberg import schema +from pyiceberg.exceptions import ResolveError from pyiceberg.expressions import Accessor -from pyiceberg.schema import Schema, build_position_accessors, prune_columns +from pyiceberg.schema import ( + Schema, + build_position_accessors, + promote, + prune_columns, +) from pyiceberg.typedef import EMPTY_DICT, StructProtocol from pyiceberg.types import ( + BinaryType, BooleanType, + DateType, + DecimalType, + DoubleType, + FixedType, FloatType, + IcebergType, IntegerType, ListType, + LongType, MapType, NestedField, StringType, StructType, + TimestampType, + TimestamptzType, + TimeType, + UUIDType, ) +TEST_PRIMITIVE_TYPES = [ + BooleanType(), + IntegerType(), + LongType(), + FloatType(), + DoubleType(), + DecimalType(10, 2), + DecimalType(100, 2), + StringType(), + DateType(), + TimeType(), + TimestamptzType(), + TimestampType(), + BinaryType(), + FixedType(16), + FixedType(20), + UUIDType(), +] + def test_schema_str(table_schema_simple: Schema) -> None: """Test casting a schema to a string""" @@ -738,3 +774,37 @@ def test_schema_select_cant_be_found(table_schema_nested: Schema) -> None: with pytest.raises(ValueError) as exc_info: table_schema_nested.select("BAZ", case_sensitive=True) assert "Could not find column: 'BAZ'" in str(exc_info.value) + + +def should_promote(file_type: IcebergType, read_type: IcebergType) -> bool: + if isinstance(file_type, IntegerType) and isinstance(read_type, LongType): + return True + if isinstance(file_type, FloatType) and isinstance(read_type, DoubleType): + return True + if isinstance(file_type, StringType) and isinstance(read_type, BinaryType): + return True + if isinstance(file_type, BinaryType) and isinstance(read_type, StringType): + return True + if isinstance(file_type, DecimalType) and isinstance(read_type, DecimalType): + return file_type.precision <= read_type.precision and file_type.scale == file_type.scale + if isinstance(file_type, FixedType) and isinstance(read_type, UUIDType) and len(file_type) == 16: + return True + return False + + +@pytest.mark.parametrize( + "file_type", + TEST_PRIMITIVE_TYPES, +) +@pytest.mark.parametrize( + "read_type", + TEST_PRIMITIVE_TYPES, +) +def test_promotion(file_type: IcebergType, read_type: IcebergType) -> None: + if file_type == read_type: + return + if should_promote(file_type, read_type): + assert promote(file_type, read_type) == read_type + else: + with pytest.raises(ResolveError): + promote(file_type, read_type) diff --git a/python/tests/test_transforms.py b/python/tests/test_transforms.py index 5c4980d8ac42..8d2fe1990528 100644 --- a/python/tests/test_transforms.py +++ b/python/tests/test_transforms.py @@ -117,6 +117,7 @@ (b"\x00\x01\x02\x03", FixedType(4), -188683207), ("iceberg", StringType(), 1210000089), (UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7"), UUIDType(), 1488055340), + (b"\xf7\x9c>\tg|K\xbd\xa4y?4\x9c\xb7\x85\xe7", UUIDType(), 1488055340), ], ) def test_bucket_hash_values(test_input: Any, test_type: PrimitiveType, expected: Any) -> None: @@ -138,6 +139,11 @@ def test_bucket_hash_values(test_input: Any, test_type: PrimitiveType, expected: UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7"), 40, ), + ( + BucketTransform(100).transform(UUIDType()), + b"\xf7\x9c>\tg|K\xbd\xa4y?4\x9c\xb7\x85\xe7", + 40, + ), (BucketTransform(128).transform(FixedType(3)), b"foo", 32), (BucketTransform(128).transform(BinaryType()), b"\x00\x01\x02\x03", 57), ],