From 51f8fb0d445323505be5c1b2c0d518515a5d6462 Mon Sep 17 00:00:00 2001
From: Harsh Motwani
Date: Wed, 24 Jul 2024 12:44:38 -0700
Subject: [PATCH 01/17] added support for interval types in the variant spec

---
 .../resources/error/error-conditions.json | 6 +
 .../spark/util/DayTimeIntervalUtils.java | 116 ++++++++
 .../spark/util/YearMonthIntervalUtils.java | 31 +++
 common/variant/README.md | 46 ++--
 .../apache/spark/types/variant/Variant.java | 35 +++
 .../spark/types/variant/VariantBuilder.java | 16 ++
 .../spark/types/variant/VariantUtil.java | 71 ++++-
 python/pyspark/sql/tests/test_types.py | 255 ++++++++++++++++++
 python/pyspark/sql/variant_utils.py | 223 ++++++++++++++-
 .../variant/VariantExpressionEvalUtils.scala | 4 +
 .../variant/variantExpressions.scala | 21 +-
 .../variant/VariantExpressionSuite.scala | 215 ++++++++++++++-
 streaming/pom.xml | 2 +-
 13 files changed, 1011 insertions(+), 30 deletions(-)
 create mode 100644 common/utils/src/main/scala/org/apache/spark/util/DayTimeIntervalUtils.java
 create mode 100644 common/utils/src/main/scala/org/apache/spark/util/YearMonthIntervalUtils.java

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 44f0a59a4b48e..a6f22de5eebe5 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -4186,6 +4186,12 @@
 ],
 "sqlState" : "42846"
 },
+ "UNKNOWN_PRIMITIVE_TYPE_IN_VARIANT" : {
+ "message" : [
+ "Unknown primitive type with id <id> was found in a variant value. The type might be supported in a newer version."
+ ],
+ "sqlState" : "22023"
+ },
 "UNKNOWN_PROTOBUF_MESSAGE_TYPE" : {
 "message" : [
 "Attempting to treat <descriptorName> as a Message, but it was <containingType>."
diff --git a/common/utils/src/main/scala/org/apache/spark/util/DayTimeIntervalUtils.java b/common/utils/src/main/scala/org/apache/spark/util/DayTimeIntervalUtils.java
new file mode 100644
index 0000000000000..83c4d0375b6c4
--- /dev/null
+++ b/common/utils/src/main/scala/org/apache/spark/util/DayTimeIntervalUtils.java
@@ -0,0 +1,116 @@
+package org.apache.spark.util;
+
+import org.apache.spark.SparkException;
+
+import java.math.BigDecimal;
+import java.util.ArrayList;
+
+// Replicating code from SparkIntervalUtils so code in the 'common' space can work with
+// day-time intervals.
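+// These helpers are currently used when converting variant values to JSON, where day-time
+// interval values are rendered as ANSI interval strings.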
+public class DayTimeIntervalUtils {
+ private static byte DAY = 0;
+ private static byte HOUR = 1;
+ private static byte MINUTE = 2;
+ private static byte SECOND = 3;
+ private static long HOURS_PER_DAY = 24;
+ private static long MINUTES_PER_HOUR = 60;
+ private static long SECONDS_PER_MINUTE = 60;
+ private static long MILLIS_PER_SECOND = 1000;
+ private static long MICROS_PER_MILLIS = 1000;
+ private static long MICROS_PER_SECOND = MICROS_PER_MILLIS * MILLIS_PER_SECOND;
+ private static long MICROS_PER_MINUTE = SECONDS_PER_MINUTE * MICROS_PER_SECOND;
+ private static long MICROS_PER_HOUR = MINUTES_PER_HOUR * MICROS_PER_MINUTE;
+ private static long MICROS_PER_DAY = HOURS_PER_DAY * MICROS_PER_HOUR;
+ private static long MAX_DAY = Long.MAX_VALUE / MICROS_PER_DAY;
+ private static long MAX_HOUR = Long.MAX_VALUE / MICROS_PER_HOUR;
+ private static long MAX_MINUTE = Long.MAX_VALUE / MICROS_PER_MINUTE;
+ private static long MAX_SECOND = Long.MAX_VALUE / MICROS_PER_SECOND;
+
+ public static String fieldToString(byte field) throws SparkException {
+ if (field == DAY) {
+ return "DAY";
+ } else if (field == HOUR) {
+ return "HOUR";
+ } else if (field == MINUTE) {
+ return "MINUTE";
+ } else if (field == SECOND) {
+ return "SECOND";
+ } else {
+ throw new SparkException("Invalid field in day-time interval: " + field +
+ ". Supported fields are: DAY, HOUR, MINUTE, SECOND");
+ }
+ }
+
+ public static String toDayTimeIntervalANSIString(long micros, byte startField, byte endField)
+ throws SparkException {
+ String sign = "";
+ long rest = micros;
+ try {
+ String from = fieldToString(startField).toUpperCase();
+ String to = fieldToString(endField).toUpperCase();
+ String prefix = "INTERVAL '";
+ String postfix = startField == endField ? "' " + from : "' " + from + " TO " + to;
+ if (micros < 0) {
+ if (micros == Long.MIN_VALUE) {
+ // Special handling of minimum `Long` value because negate op overflows `Long`.
+ // seconds = 106751991 * (24 * 60 * 60) + 4 * 60 * 60 + 54 = 9223372036854
+ // microseconds = -9223372036854000000L-775808 == Long.MinValue
+ String baseStr = "-106751991 04:00:54.775808000";
+ String firstStr = "-" + (startField == DAY ? Long.toString(MAX_DAY) :
+ (startField == HOUR ? Long.toString(MAX_HOUR) :
+ (startField == MINUTE ? Long.toString(MAX_MINUTE) :
+ Long.toString(MAX_SECOND) + ".775808")));
+ if (startField == endField) {
+ return prefix + firstStr + postfix;
+ } else {
+ int substrStart = startField == DAY ? 10 : (startField == HOUR ? 13 : 16);
+ int substrEnd = endField == HOUR ? 13 : (endField == MINUTE ? 16 : 26);
+ return prefix + firstStr + baseStr.substring(substrStart, substrEnd) + postfix;
+ }
+ } else {
+ sign = "-";
+ rest = -rest;
+ }
+ }
+ StringBuilder formatBuilder = new StringBuilder(sign);
+ ArrayList<Long> formatArgs = new ArrayList<>();
+ if (startField == DAY) {
+ formatBuilder.append(rest / MICROS_PER_DAY);
+ rest %= MICROS_PER_DAY;
+ } else if (startField == HOUR) {
+ formatBuilder.append("%02d");
+ formatArgs.add(rest / MICROS_PER_HOUR);
+ rest %= MICROS_PER_HOUR;
+ } else if (startField == MINUTE) {
+ formatBuilder.append("%02d");
+ formatArgs.add(rest / MICROS_PER_MINUTE);
+ rest %= MICROS_PER_MINUTE;
+ } else if (startField == SECOND) {
+ String leadZero = rest < 10 * MICROS_PER_SECOND ? 
"0" : ""; + formatBuilder.append(leadZero + BigDecimal.valueOf(rest, 6) + .stripTrailingZeros().toPlainString()); + } + + if (startField < HOUR && HOUR <= endField) { + formatBuilder.append(" %02d"); + formatArgs.add(rest / MICROS_PER_HOUR); + rest %= MICROS_PER_HOUR; + } + if (startField < MINUTE && MINUTE <= endField) { + formatBuilder.append(":%02d"); + formatArgs.add(rest / MICROS_PER_MINUTE); + rest %= MICROS_PER_MINUTE; + } + if (startField < SECOND && SECOND <= endField) { + String leadZero = rest < 10 * MICROS_PER_SECOND ? "0" : ""; + formatBuilder.append(":" + leadZero + BigDecimal.valueOf(rest, 6) + .stripTrailingZeros().toPlainString()); + } + return prefix + String.format(formatBuilder.toString(), formatArgs.toArray()) + postfix; + } catch (SparkException e) { + throw e; + } + } +} diff --git a/common/utils/src/main/scala/org/apache/spark/util/YearMonthIntervalUtils.java b/common/utils/src/main/scala/org/apache/spark/util/YearMonthIntervalUtils.java new file mode 100644 index 0000000000000..84d8bd25090c1 --- /dev/null +++ b/common/utils/src/main/scala/org/apache/spark/util/YearMonthIntervalUtils.java @@ -0,0 +1,31 @@ +package org.apache.spark.util; + +// Replicating code from SparkIntervalUtils so code in the 'common' space can work with +// year-month intervals. +public class YearMonthIntervalUtils { + private static byte YEAR = 0; + private static byte MONTH = 1; + private static int MONTHS_PER_YEAR = 12; + + public static String toYearMonthIntervalANSIString(int months, byte startField, byte endField) { + String sign = ""; + long absMonths = months; + if (months < 0) { + sign = "-"; + absMonths = -absMonths; + } + String year = sign + Long.toString(absMonths / MONTHS_PER_YEAR); + String yearAndMonth = year + "-" + Long.toString(absMonths % MONTHS_PER_YEAR); + StringBuilder formatBuilder = new StringBuilder("INTERVAL '"); + if (startField == endField) { + if (startField == YEAR) { + formatBuilder.append(year + "' YEAR"); + } else { + formatBuilder.append(Integer.toString(months) + "' MONTH"); + } + } else { + formatBuilder.append(yearAndMonth + "' YEAR TO MONTH"); + } + return formatBuilder.toString(); + } +} diff --git a/common/variant/README.md b/common/variant/README.md index 3e1b00c494755..12c14361941cd 100644 --- a/common/variant/README.md +++ b/common/variant/README.md @@ -335,27 +335,29 @@ The Decimal type contains a scale, but no precision. 
The implied precision of a | Object | `2` | A collection of (string-key, variant-value) pairs | | Array | `3` | An ordered sequence of variant values | -| Primitive Type | Type ID | Equivalent Parquet Type | Binary format | -|-----------------------------|---------|---------------------------|-----------------------------------------------------------------------------------------------------------| -| null | `0` | any | none | -| boolean (True) | `1` | BOOLEAN | none | -| boolean (False) | `2` | BOOLEAN | none | -| int8 | `3` | INT(8, signed) | 1 byte | -| int16 | `4` | INT(16, signed) | 2 byte little-endian | -| int32 | `5` | INT(32, signed) | 4 byte little-endian | -| int64 | `6` | INT(64, signed) | 8 byte little-endian | -| double | `7` | DOUBLE | IEEE little-endian | -| decimal4 | `8` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | -| decimal8 | `9` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | -| decimal16 | `10` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | -| date | `11` | DATE | 4 byte little-endian | -| timestamp | `12` | TIMESTAMP(true, MICROS) | 8-byte little-endian | -| timestamp without time zone | `13` | TIMESTAMP(false, MICROS) | 8-byte little-endian | -| float | `14` | FLOAT | IEEE little-endian | -| binary | `15` | BINARY | 4 byte little-endian size, followed by bytes | -| string | `16` | STRING | 4 byte little-endian size, followed by UTF-8 encoded bytes | -| binary from metadata | `17` | BINARY | Little-endian index into the metadata dictionary. Number of bytes is equal to the metadata `offset_size`. | -| string from metadata | `18` | STRING | Little-endian index into the metadata dictionary. Number of bytes is equal to the metadata `offset_size`. | +| Primitive Type | Type ID | Equivalent Parquet Type | Binary format | +|-----------------------------|---------|-----------------------------------------------|---------------------------------------------------------------------------------------------------------------------| +| null | `0` | any | none | +| boolean (True) | `1` | BOOLEAN | none | +| boolean (False) | `2` | BOOLEAN | none | +| int8 | `3` | INT(8, signed) | 1 byte | +| int16 | `4` | INT(16, signed) | 2 byte little-endian | +| int32 | `5` | INT(32, signed) | 4 byte little-endian | +| int64 | `6` | INT(64, signed) | 8 byte little-endian | +| double | `7` | DOUBLE | IEEE little-endian | +| decimal4 | `8` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | +| decimal8 | `9` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | +| decimal16 | `10` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | +| date | `11` | DATE | 4 byte little-endian | +| timestamp | `12` | TIMESTAMP(true, MICROS) | 8-byte little-endian | +| timestamp without time zone | `13` | TIMESTAMP(false, MICROS) | 8-byte little-endian | +| float | `14` | FLOAT | IEEE little-endian | +| binary | `15` | BINARY | 4 byte little-endian size, followed by bytes | +| string | `16` | STRING | 4 byte little-endian size, followed by UTF-8 encoded bytes | +| binary from metadata | `17` | BINARY | Little-endian index into the metadata dictionary. 
Number of bytes is equal to the metadata `offset_size`. |
+| string from metadata | `18` | STRING | Little-endian index into the metadata dictionary. Number of bytes is equal to the metadata `offset_size`. |
+| year-month interval | `19` | YearMonthIntervalType(start_field, end_field) | 1 byte denoting the start field (1 bit) and end field (1 bit) starting at the LSB, followed by a 4-byte little-endian value. |
+| day-time interval | `20` | DayTimeIntervalType(start_field, end_field) | 1 byte denoting the start field (2 bits) and end field (2 bits) starting at the LSB, followed by an 8-byte little-endian value. |

 | Decimal Precision | Decimal value type |
 |-----------------------|--------------------|
@@ -364,6 +366,8 @@ The Decimal type contains a scale, but no precision. The implied precision of a
 | 18 <= precision <= 38 | int128 |
 | > 38 | Not supported |

+The year-month and day-time interval types have one byte at the beginning indicating the start and end fields. In the case of the year-month interval, the least significant bit denotes the start field and the next least significant bit denotes the end field. The remaining 6 bits are unused. A field value of 0 represents YEAR and 1 represents MONTH. In the case of the day-time interval, the least significant 2 bits denote the start field and the next least significant 2 bits denote the end field. The remaining 4 bits are unused. A field value of 0 represents DAY, 1 represents HOUR, 2 represents MINUTE, and 3 represents SECOND. For example, a day-time interval with start field DAY and end field SECOND stores the byte `0b00001100` (end field SECOND = 3 in bits 2-3, start field DAY = 0 in bits 0-1) before its 8-byte value.
+
 # Field ID order and uniqueness

 For objects, field IDs and offsets must be listed in the order of the corresponding field names, sorted lexicographically. Note that the fields themselves are not required to follow this order. As a result, offsets will not necessarily be listed in ascending order.
diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
index a705daaf323b2..ecb25c51aa0d1 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
@@ -34,6 +34,9 @@
 import java.util.Base64;
 import java.util.Locale;

+import org.apache.spark.util.DayTimeIntervalUtils;
+import org.apache.spark.util.YearMonthIntervalUtils;
+
 import static org.apache.spark.types.variant.VariantUtil.*;

 /**
@@ -88,6 +91,16 @@ public long getLong() {
 return VariantUtil.getLong(value, pos);
 }

+ // Get the start and end fields of a year-month interval from the variant.
+ public IntervalFields getYearMonthIntervalFields() {
+ return VariantUtil.getYearMonthIntervalFields(value, pos);
+ }
+
+ // Get the start and end fields of a day-time interval from the variant.
+ public IntervalFields getDayTimeIntervalFields() {
+ return VariantUtil.getDayTimeIntervalFields(value, pos);
+ }
+
 // Get a double value from the variant.
 public double getDouble() {
 return VariantUtil.getDouble(value, pos);
@@ -113,6 +126,12 @@ public String getString() {
 return VariantUtil.getString(value, pos);
 }

+ // Get the type info bits from a variant value.
+ public int getTypeInfo() {
+ return VariantUtil.getTypeInfo(value, pos);
+ }
+
 // Get the value type of the variant.
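+ // Throws UNKNOWN_PRIMITIVE_TYPE_IN_VARIANT if the primitive type id is not recognized.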
 public Type getType() {
 return VariantUtil.getType(value, pos);
@@ -316,6 +335,22 @@ static void toJsonImpl(byte[] value, byte[] metadata, int pos, StringBuilder sb,
 case BINARY:
 appendQuoted(sb, Base64.getEncoder().encodeToString(VariantUtil.getBinary(value, pos)));
 break;
+ case YEAR_MONTH_INTERVAL:
+ IntervalFields ymFields = VariantUtil.getYearMonthIntervalFields(value, pos);
+ int ymValue = (int) VariantUtil.getLong(value, pos);
+ appendQuoted(sb, YearMonthIntervalUtils
+ .toYearMonthIntervalANSIString(ymValue, ymFields.startField, ymFields.endField));
+ break;
+ case DAY_TIME_INTERVAL:
+ IntervalFields dtFields = VariantUtil.getDayTimeIntervalFields(value, pos);
+ long dtValue = VariantUtil.getLong(value, pos);
+ try {
+ appendQuoted(sb, DayTimeIntervalUtils.toDayTimeIntervalANSIString(dtValue,
+ dtFields.startField, dtFields.endField));
+ } catch (Exception e) {
+ throw malformedVariant();
+ }
+ break;
 }
 }
 }
diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java
index 2afba81d192e9..f603b787da12d 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java
@@ -214,6 +214,22 @@ public void appendTimestampNtz(long microsSinceEpoch) {
 writePos += 8;
 }

+ public void appendYearMonthInterval(long value, byte startField, byte endField) {
+ checkCapacity(1 + 5);
+ writeBuffer[writePos++] = primitiveHeader(YEAR_MONTH_INTERVAL);
+ writeBuffer[writePos++] = (byte) (startField | (endField << 1));
+ writeLong(writeBuffer, writePos, value, 4);
+ writePos += 4;
+ }
+
+ public void appendDayTimeInterval(long value, byte startField, byte endField) {
+ checkCapacity(1 + 9);
+ writeBuffer[writePos++] = primitiveHeader(DAY_TIME_INTERVAL);
+ writeBuffer[writePos++] = (byte) (startField | (endField << 2));
+ writeLong(writeBuffer, writePos, value, 8);
+ writePos += 8;
+ }
+
 public void appendFloat(float f) {
 checkCapacity(1 + 4);
 writeBuffer[writePos++] = primitiveHeader(FLOAT);
diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
index 84e3a45e4b0ee..3177c81223186 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
@@ -120,6 +120,12 @@ public class VariantUtil {
 // Long string value. The content is (4-byte little-endian unsigned integer representing the
 // string size) + (size bytes of string content).
 public static final int LONG_STR = 16;
+ // year-month interval value. The content is one byte representing the start and end field values
+ // (1 bit each, starting at the least significant bit) and a 4-byte little-endian signed integer.
+ public static final int YEAR_MONTH_INTERVAL = 19;
+ // day-time interval value. The content is one byte representing the start and end field values
+ // (2 bits each, starting at the least significant bit) and an 8-byte little-endian signed integer.
+ public static final int DAY_TIME_INTERVAL = 20;

 public static final byte VERSION = 1;
 // The lower 4 bits of the first metadata byte contain the version.
@@ -171,6 +177,12 @@ static SparkRuntimeException malformedVariant() {
 Map$.MODULE$.empty(), null, new QueryContext[]{}, "");
 }

+ static SparkRuntimeException unknownPrimitiveTypeInVariant(int id) {
+ return new SparkRuntimeException("UNKNOWN_PRIMITIVE_TYPE_IN_VARIANT",
+ new scala.collection.immutable.Map.Map1<>("id", Integer.toString(id)), null,
+ new QueryContext[]{}, "");
+ }
+
 // An exception indicating that an external caller tried to call the Variant constructor with
 // value or metadata exceeding the 16MiB size limit. We will never construct a Variant this large,
 // so it should only be possible to encounter this exception when reading a Variant produced by
@@ -233,6 +245,14 @@ public enum Type {
 TIMESTAMP_NTZ,
 FLOAT,
 BINARY,
+ YEAR_MONTH_INTERVAL,
+ DAY_TIME_INTERVAL,
+ }
+
+ // Get the type info bits from a variant value `value[pos...]`.
+ public static int getTypeInfo(byte[] value, int pos) {
+ checkIndex(pos, value.length);
+ return (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
 }

 // Get the value type of variant value `value[pos...]`. It is only legal to call `get*` if
@@ -280,8 +300,12 @@ public static Type getType(byte[] value, int pos) {
 return Type.BINARY;
 case LONG_STR:
 return Type.STRING;
+ case YEAR_MONTH_INTERVAL:
+ return Type.YEAR_MONTH_INTERVAL;
+ case DAY_TIME_INTERVAL:
+ return Type.DAY_TIME_INTERVAL;
 default:
- throw malformedVariant();
+ throw unknownPrimitiveTypeInVariant(typeInfo);
 }
 }
 }
@@ -322,8 +346,10 @@ public static int valueSize(byte[] value, int pos) {
 case TIMESTAMP:
 case TIMESTAMP_NTZ:
 return 9;
+ case YEAR_MONTH_INTERVAL:
 case DECIMAL4:
 return 6;
+ case DAY_TIME_INTERVAL:
 case DECIMAL8:
 return 10;
 case DECIMAL16:
@@ -332,7 +358,7 @@
 case LONG_STR:
 return 1 + U32_SIZE + readUnsigned(value, pos + 1, U32_SIZE);
 default:
- throw malformedVariant();
+ throw unknownPrimitiveTypeInVariant(typeInfo);
 }
 }
 }
@@ -377,11 +403,52 @@ public static long getLong(byte[] value, int pos) {
 case TIMESTAMP:
 case TIMESTAMP_NTZ:
 return readLong(value, pos + 1, 8);
+ case YEAR_MONTH_INTERVAL:
+ return readLong(value, pos + 2, 4);
+ case DAY_TIME_INTERVAL:
+ return readLong(value, pos + 2, 8);
 default:
 throw new IllegalStateException(exceptionMessage);
 }
 }

+ public static class IntervalFields {
+ public IntervalFields(byte startField, byte endField) {
+ this.startField = startField;
+ this.endField = endField;
+ }
+
+ public final byte startField;
+ public final byte endField;
+ }
+
+ // Get the start and end fields of a variant value representing a year-month interval value,
+ // returned as an IntervalFields object. Throws `MALFORMED_VARIANT` if the end field is
+ // smaller than the start field.
+ public static IntervalFields getYearMonthIntervalFields(byte[] value, int pos) {
+ long fieldInfo = readLong(value, pos + 1, 1);
+ IntervalFields intervalFields = new IntervalFields((byte) (fieldInfo & 0x1),
+ (byte) ((fieldInfo >> 1) & 0x1));
+ if (intervalFields.endField < intervalFields.startField) {
+ throw malformedVariant();
+ }
+ return intervalFields;
+ }
+
+ // Get the start and end fields of a variant value representing a day-time interval value.
+ // The returned IntervalFields object contains the start field and the end field of the
+ // interval.
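+ // Throws `MALFORMED_VARIANT` if the end field is smaller than the start field.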
+ public static IntervalFields getDayTimeIntervalFields(byte[] value, int pos) { + long fieldInfo = readLong(value, pos + 1, 1); + IntervalFields intervalFields = new IntervalFields((byte) (fieldInfo & 0x3), + (byte) ((fieldInfo >> 2) & 0x3)); + if (intervalFields.endField < intervalFields.startField) { + throw malformedVariant(); + } + return intervalFields; + } + // Get a double value from variant value `value[pos...]`. // Throw `MALFORMED_VARIANT` if the variant is malformed. public static double getDouble(byte[] value, int pos) { diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 69afc2a5b5ed3..9482fc35671fc 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -2125,6 +2125,100 @@ def test_variant_type(self): + " as timestamp) as variant) as t1, cast(cast('0001-12-31 01:01:01+08:00'" + " as timestamp) as variant) as t2" ).collect()[0] + # Highest possible DT interval value + high_dt_interval_columns = self.spark.sql( + "select 9223372036854.775807::interval day to second::variant as dti00, " + + "9223372036854.775807::interval hour to second::variant as dti01, " + + "9223372036854.775807::interval minute to second::variant as dti02, " + + "9223372036854.775807::interval second::variant as dti03, " + + "153722867280.912930::interval day to minute::variant as dti10, " + + "153722867280.912930::interval hour to minute::variant as dti11, " + + "153722867280.912930::interval minute::variant as dti12, " + + "2562047788.015215::interval day to hour::variant as dti20, " + + "2562047788.015215::interval hour::variant as dti21, " + + "106751991.167300::interval day::variant as dti30" + ).collect()[0] + # Lowest possible DT interval value + low_dt_interval_columns = self.spark.sql( + "select -9223372036854.775808::interval day to second::variant as dti00, " + + "-9223372036854.775808::interval hour to second::variant as dti01, " + + "-9223372036854.775808::interval minute to second::variant as dti02, " + + "-9223372036854.775808::interval second::variant as dti03, " + + "-153722867280.912930::interval day to minute::variant as dti10, " + + "-153722867280.912930::interval hour to minute::variant as dti11, " + + "-153722867280.912930::interval minute::variant as dti12, " + + "-2562047788.015215::interval day to hour::variant as dti20, " + + "-2562047788.015215::interval hour::variant as dti21, " + + "-106751991.167300::interval day::variant as dti30" + ).collect()[0] + zero_dt_interval_columns = self.spark.sql( + "select 0::interval day to second::variant as dti00, " + + "0::interval hour to second::variant as dti01, " + + "0::interval minute to second::variant as dti02, " + + "0::interval second::variant as dti03, " + + "0::interval day to minute::variant as dti10, " + + "0::interval hour to minute::variant as dti11, " + + "0::interval minute::variant as dti12, " + + "0::interval day to hour::variant as dti20, " + + "0::interval hour::variant as dti21, " + + "0::interval day::variant as dti30" + ).collect()[0] + # Random positive dt interval value + rand_pos_dt_interval_columns = self.spark.sql( + "select 12893121435::interval day to second::variant as dti00, " + + "273457447832::interval hour to second::variant as dti01, " + + "234233247::interval minute to second::variant as dti02, " + + "9310354::interval second::variant as dti03, " + + "214885357::interval day to minute::variant as dti10, " + + "4557624130::interval hour to minute::variant as dti11, " + + "3903887::interval minute::variant as dti12, 
" + + "3581422::interval day to hour::variant as dti20, " + + "75960402::interval hour::variant as dti21, " + + "65064::interval day::variant as dti30" + ).collect()[0] + # Random negative dt interval value + rand_neg_dt_interval_columns = self.spark.sql( + "select -426547473652::interval day to second::variant as dti00, " + + "-2327834334::interval hour to second::variant as dti01, " + + "-324223232::interval minute to second::variant as dti02, " + + "-2342332::interval second::variant as dti03, " + + "-7109124560::interval day to minute::variant as dti10, " + + "-38797238::interval hour to minute::variant as dti11, " + + "-5403720::interval minute::variant as dti12, " + + "-118485409::interval day to hour::variant as dti20, " + + "-646620::interval hour::variant as dti21, " + + "-90062::interval day::variant as dti30" + ).collect()[0] + + # Highest possible ym interval value + high_ym_interval_columns = self.spark.sql( + "select 2147483647::interval year to month::variant as ymi0, " + + "2147483647::interval month::variant as ymi1, " + + "178956970::interval year::variant as ymi2" + ).collect()[0] + # Lowest possible ym interval value + low_ym_interval_columns = self.spark.sql( + "select -2147483648::interval year to month::variant as ymi0, " + + "-2147483648::interval month::variant as ymi1, " + + "-178956970::interval year::variant as ymi2" + ).collect()[0] + zero_ym_interval_columns = self.spark.sql( + "select 0::interval year to month::variant ymi0, " + + "0::interval month::variant ymi1, " + + "0::interval year::variant ymi2" + ).collect()[0] + # Random positive ym interval value + rand_pos_ym_interval_columns = self.spark.sql( + "select 24678537::interval year to month::variant ymi0, " + + "345763467::interval month::variant ymi1, " + + "45723888::interval year::variant ymi2" + ).collect()[0] + # Random negative ym interval value + rand_neg_ym_interval_columns = self.spark.sql( + "select -425245345::interval year to month::variant ymi0, " + + "-849348229::interval month::variant ymi1, " + + "-85349890::interval year::variant ymi2" + ).collect()[0] variants = [ row["v"], @@ -2142,6 +2236,71 @@ def test_variant_type(self): timetamp_columns["t0"], timetamp_columns["t1"], timetamp_columns["t2"], + high_dt_interval_columns["dti00"], + high_dt_interval_columns["dti01"], + high_dt_interval_columns["dti02"], + high_dt_interval_columns["dti03"], + high_dt_interval_columns["dti10"], + high_dt_interval_columns["dti11"], + high_dt_interval_columns["dti12"], + high_dt_interval_columns["dti20"], + high_dt_interval_columns["dti21"], + high_dt_interval_columns["dti30"], + low_dt_interval_columns["dti00"], + low_dt_interval_columns["dti01"], + low_dt_interval_columns["dti02"], + low_dt_interval_columns["dti03"], + low_dt_interval_columns["dti10"], + low_dt_interval_columns["dti11"], + low_dt_interval_columns["dti12"], + low_dt_interval_columns["dti20"], + low_dt_interval_columns["dti21"], + low_dt_interval_columns["dti30"], + zero_dt_interval_columns["dti00"], + zero_dt_interval_columns["dti01"], + zero_dt_interval_columns["dti02"], + zero_dt_interval_columns["dti03"], + zero_dt_interval_columns["dti10"], + zero_dt_interval_columns["dti11"], + zero_dt_interval_columns["dti12"], + zero_dt_interval_columns["dti20"], + zero_dt_interval_columns["dti21"], + zero_dt_interval_columns["dti30"], + rand_pos_dt_interval_columns["dti00"], + rand_pos_dt_interval_columns["dti01"], + rand_pos_dt_interval_columns["dti02"], + rand_pos_dt_interval_columns["dti03"], + rand_pos_dt_interval_columns["dti10"], + 
rand_pos_dt_interval_columns["dti11"], + rand_pos_dt_interval_columns["dti12"], + rand_pos_dt_interval_columns["dti20"], + rand_pos_dt_interval_columns["dti21"], + rand_pos_dt_interval_columns["dti30"], + rand_neg_dt_interval_columns["dti00"], + rand_neg_dt_interval_columns["dti01"], + rand_neg_dt_interval_columns["dti02"], + rand_neg_dt_interval_columns["dti03"], + rand_neg_dt_interval_columns["dti10"], + rand_neg_dt_interval_columns["dti11"], + rand_neg_dt_interval_columns["dti12"], + rand_neg_dt_interval_columns["dti20"], + rand_neg_dt_interval_columns["dti21"], + rand_neg_dt_interval_columns["dti30"], + high_ym_interval_columns["ymi0"], + high_ym_interval_columns["ymi1"], + high_ym_interval_columns["ymi2"], + low_ym_interval_columns["ymi0"], + low_ym_interval_columns["ymi1"], + low_ym_interval_columns["ymi2"], + zero_ym_interval_columns["ymi0"], + zero_ym_interval_columns["ymi1"], + zero_ym_interval_columns["ymi2"], + rand_pos_ym_interval_columns["ymi0"], + rand_pos_ym_interval_columns["ymi1"], + rand_pos_ym_interval_columns["ymi2"], + rand_neg_ym_interval_columns["ymi0"], + rand_neg_ym_interval_columns["ymi1"], + rand_neg_ym_interval_columns["ymi2"], ] for v in variants: @@ -2165,6 +2324,77 @@ def test_variant_type(self): self.assertEqual(str(variants[12]), '"1940-01-01 05:05:13.123000+00:00"') self.assertEqual(str(variants[13]), '"2522-12-31 05:23:00+00:00"') self.assertEqual(str(variants[14]), '"0001-12-30 17:01:01+00:00"') + self.assertEqual(str(variants[15]), + '"INTERVAL \'106751991 04:00:54.775807\' DAY TO SECOND"') + self.assertEqual(str(variants[16]), + '"INTERVAL \'2562047788:00:54.775807\' HOUR TO SECOND"') + self.assertEqual(str(variants[17]), + '"INTERVAL \'153722867280:54.775807\' MINUTE TO SECOND"') + self.assertEqual(str(variants[18]), '"INTERVAL \'9223372036854.775807\' SECOND"') + self.assertEqual(str(variants[19]), '"INTERVAL \'106751991 04:00\' DAY TO MINUTE"') + self.assertEqual(str(variants[20]), '"INTERVAL \'2562047788:00\' HOUR TO MINUTE"') + self.assertEqual(str(variants[21]), '"INTERVAL \'153722867280\' MINUTE"') + self.assertEqual(str(variants[22]), '"INTERVAL \'106751991 04\' DAY TO HOUR"') + self.assertEqual(str(variants[23]), '"INTERVAL \'2562047788\' HOUR"') + self.assertEqual(str(variants[24]), '"INTERVAL \'106751991\' DAY"') + self.assertEqual(str(variants[25]), + '"INTERVAL \'-106751991 04:00:54.775808\' DAY TO SECOND"') + self.assertEqual(str(variants[26]), + '"INTERVAL \'-2562047788:00:54.775808\' HOUR TO SECOND"') + self.assertEqual(str(variants[27]), + '"INTERVAL \'-153722867280:54.775808\' MINUTE TO SECOND"') + self.assertEqual(str(variants[28]), '"INTERVAL \'-9223372036854.775808\' SECOND"') + self.assertEqual(str(variants[29]), '"INTERVAL \'-106751991 04:00\' DAY TO MINUTE"') + self.assertEqual(str(variants[30]), '"INTERVAL \'-2562047788:00\' HOUR TO MINUTE"') + self.assertEqual(str(variants[31]), '"INTERVAL \'-153722867280\' MINUTE"') + self.assertEqual(str(variants[32]), '"INTERVAL \'-106751991 04\' DAY TO HOUR"') + self.assertEqual(str(variants[33]), '"INTERVAL \'-2562047788\' HOUR"') + self.assertEqual(str(variants[34]), '"INTERVAL \'-106751991\' DAY"') + self.assertEqual(str(variants[35]), '"INTERVAL \'0 00:00:00\' DAY TO SECOND"') + self.assertEqual(str(variants[36]), '"INTERVAL \'00:00:00\' HOUR TO SECOND"') + self.assertEqual(str(variants[37]), '"INTERVAL \'00:00\' MINUTE TO SECOND"') + self.assertEqual(str(variants[38]), '"INTERVAL \'00\' SECOND"') + self.assertEqual(str(variants[39]), '"INTERVAL \'0 00:00\' DAY TO MINUTE"') + 
self.assertEqual(str(variants[40]), '"INTERVAL \'00:00\' HOUR TO MINUTE"') + self.assertEqual(str(variants[41]), '"INTERVAL \'00\' MINUTE"') + self.assertEqual(str(variants[42]), '"INTERVAL \'0 00\' DAY TO HOUR"') + self.assertEqual(str(variants[43]), '"INTERVAL \'00\' HOUR"') + self.assertEqual(str(variants[44]), '"INTERVAL \'0\' DAY"') + self.assertEqual(str(variants[45]), '"INTERVAL \'149225 22:37:15\' DAY TO SECOND"') + self.assertEqual(str(variants[46]), '"INTERVAL \'75960402:10:32\' HOUR TO SECOND"') + self.assertEqual(str(variants[47]), '"INTERVAL \'3903887:27\' MINUTE TO SECOND"') + self.assertEqual(str(variants[48]), '"INTERVAL \'9310354\' SECOND"') + self.assertEqual(str(variants[49]), '"INTERVAL \'149225 22:37\' DAY TO MINUTE"') + self.assertEqual(str(variants[50]), '"INTERVAL \'75960402:10\' HOUR TO MINUTE"') + self.assertEqual(str(variants[51]), '"INTERVAL \'3903887\' MINUTE"') + self.assertEqual(str(variants[52]), '"INTERVAL \'149225 22\' DAY TO HOUR"') + self.assertEqual(str(variants[53]), '"INTERVAL \'75960402\' HOUR"') + self.assertEqual(str(variants[54]), '"INTERVAL \'65064\' DAY"') + self.assertEqual(str(variants[55]), '"INTERVAL \'-4936892 01:20:52\' DAY TO SECOND"') + self.assertEqual(str(variants[56]), '"INTERVAL \'-646620:38:54\' HOUR TO SECOND"') + self.assertEqual(str(variants[57]), '"INTERVAL \'-5403720:32\' MINUTE TO SECOND"') + self.assertEqual(str(variants[58]), '"INTERVAL \'-2342332\' SECOND"') + self.assertEqual(str(variants[59]), '"INTERVAL \'-4936892 01:20\' DAY TO MINUTE"') + self.assertEqual(str(variants[60]), '"INTERVAL \'-646620:38\' HOUR TO MINUTE"') + self.assertEqual(str(variants[61]), '"INTERVAL \'-5403720\' MINUTE"') + self.assertEqual(str(variants[62]), '"INTERVAL \'-4936892 01\' DAY TO HOUR"') + self.assertEqual(str(variants[63]), '"INTERVAL \'-646620\' HOUR"') + self.assertEqual(str(variants[64]), '"INTERVAL \'-90062\' DAY"') + self.assertEqual(str(variants[65]), '"INTERVAL \'178956970-7\' YEAR TO MONTH"') + self.assertEqual(str(variants[66]), '"INTERVAL \'2147483647\' MONTH"') + self.assertEqual(str(variants[67]), '"INTERVAL \'178956970\' YEAR"') + self.assertEqual(str(variants[68]), '"INTERVAL \'-178956970-8\' YEAR TO MONTH"') + self.assertEqual(str(variants[69]), '"INTERVAL \'-2147483648\' MONTH"') + self.assertEqual(str(variants[70]), '"INTERVAL \'-178956970\' YEAR"') + self.assertEqual(str(variants[71]), '"INTERVAL \'0-0\' YEAR TO MONTH"') + self.assertEqual(str(variants[72]), '"INTERVAL \'0\' MONTH"') + self.assertEqual(str(variants[73]), '"INTERVAL \'0\' YEAR"') + self.assertEqual(str(variants[74]), '"INTERVAL \'2056544-9\' YEAR TO MONTH"') + self.assertEqual(str(variants[75]), '"INTERVAL \'345763467\' MONTH"') + self.assertEqual(str(variants[76]), '"INTERVAL \'45723888\' YEAR"') + self.assertEqual(str(variants[77]), '"INTERVAL \'-35437112-1\' YEAR TO MONTH"') + self.assertEqual(str(variants[78]), '"INTERVAL \'-849348229\' MONTH"') + self.assertEqual(str(variants[79]), '"INTERVAL \'-85349890\' YEAR"') # Check to_json on timestamps with custom timezones self.assertEqual( @@ -2223,6 +2453,31 @@ def test_variant_type(self): tzinfo=datetime.timezone(datetime.timedelta(hours=23, minutes=2, seconds=22)), ), ) + # For day time intervals, the success of the str() tests proves that the microseconds + # are being extracted correctly. 
Therefore, not all of the cases need to be verified for
+ # toPython.
+ self.assertEqual(
+ variants[15].toPython(),
+ datetime.timedelta(microseconds=9223372036854775807),
+ )
+ self.assertEqual(
+ variants[29].toPython(),
+ datetime.timedelta(microseconds=-9223372036854775800),
+ )
+ self.assertEqual(
+ variants[42].toPython(),
+ datetime.timedelta(microseconds=0),
+ )
+ self.assertEqual(
+ variants[54].toPython(),
+ datetime.timedelta(microseconds=5621529600000000),
+ )
+ self.assertEqual(
+ variants[57].toPython(),
+ datetime.timedelta(microseconds=-324223232000000),
+ )
+
+ self.assertRaises(PySparkNotImplementedError, lambda: variants[65].toPython())

 # check repr
 self.assertEqual(str(variants[0]), str(eval(repr(variants[0]))))
diff --git a/python/pyspark/sql/variant_utils.py b/python/pyspark/sql/variant_utils.py
index 40cc69c1f0961..95df421c0d493 100644
--- a/python/pyspark/sql/variant_utils.py
+++ b/python/pyspark/sql/variant_utils.py
@@ -21,8 +21,12 @@
 import json
 import struct
 from array import array
+from decimal import Decimal
 from typing import Any, Callable, Dict, List, Tuple
-from pyspark.errors import PySparkValueError
+from pyspark.errors import (
+ PySparkNotImplementedError,
+ PySparkValueError,
+)
 from zoneinfo import ZoneInfo
@@ -107,6 +111,12 @@ class VariantUtils:
 # Long string value. The content is (4-byte little-endian unsigned integer representing the
 # string size) + (size bytes of string content).
 LONG_STR = 16
+ # year-month interval value. The content is one byte representing the start and end field values
+ # (1 bit each, starting at the least significant bit) and a 4-byte little-endian signed integer.
+ YEAR_MONTH_INTERVAL = 19
+ # day-time interval value. The content is one byte representing the start and end field values
+ # (2 bits each, starting at the least significant bit) and an 8-byte little-endian signed integer.
+ DAY_TIME_INTERVAL = 20

 U32_SIZE = 4

@@ -122,6 +132,11 @@ class VariantUtils:
 MAX_DECIMAL16_PRECISION = 38
 MAX_DECIMAL16_VALUE = 10**MAX_DECIMAL16_PRECISION

+ # There is no PySpark equivalent of the SQL year-month interval type. This class acts as a
+ # placeholder for this type.
+ class _PlaceholderYearMonthIntervalInternalType:
+ pass
+
 @classmethod
 def to_json(cls, value: bytes, metadata: bytes, zone_id: str = "UTC") -> str:
 """
@@ -160,6 +175,32 @@ def _get_type_info(cls, value: bytes, pos: int) -> Tuple[int, int]:
 type_info = (value[pos] >> VariantUtils.BASIC_TYPE_BITS) & VariantUtils.TYPE_INFO_MASK
 return (basic_type, type_info)

+ @classmethod
+ def _get_day_time_interval_fields(cls, value: bytes, pos: int) -> Tuple[int, int]:
+ """
+ Returns the (start_field, end_field) pair for a variant representing a day-time interval
+ value stored at a given position in the value.
+ """
+ cls._check_index(pos, len(value))
+ start_field = value[pos] & 0x3
+ end_field = (value[pos] >> 2) & 0x3
+ if end_field < start_field:
+ raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
+ return (start_field, end_field)
+
+ @classmethod
+ def _get_year_month_interval_fields(cls, value: bytes, pos: int) -> Tuple[int, int]:
+ """
+ Returns the (start_field, end_field) pair for a variant representing a year-month interval
+ value stored at a given position in the value.
+ """
+ cls._check_index(pos, len(value))
+ start_field = value[pos] & 0x1
+ end_field = (value[pos] >> 1) & 0x1
+ if end_field < start_field:
+ raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
+ return (start_field, end_field)
+
 @classmethod
 def _get_metadata_key(cls, metadata: bytes, id: int) -> str:
 """
@@ -235,6 +276,38 @@ def _get_timestamp(cls, value: bytes, pos: int, zone_id: str) -> datetime.dateti
 ).astimezone(ZoneInfo(zone_id))
 raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})

+ @classmethod
+ def _get_yminterval_info(cls, value: bytes, pos: int) -> Tuple[int, int, int]:
+ """
+ Returns the (months, start_field, end_field) tuple from a year-month interval value at a
+ given position in a variant.
+ """
+ cls._check_index(pos, len(value))
+ basic_type, type_info = cls._get_type_info(value, pos)
+ if basic_type != VariantUtils.PRIMITIVE:
+ raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
+ if type_info == VariantUtils.YEAR_MONTH_INTERVAL:
+ months = cls._read_long(value, pos + 2, 4, signed=True)
+ start_field, end_field = cls._get_year_month_interval_fields(value, pos + 1)
+ return (months, start_field, end_field)
+ raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
+
+ @classmethod
+ def _get_timedelta(cls, value: bytes, pos: int) -> Tuple[int, int, int]:
+ """
+ Returns the (micros, start_field, end_field) tuple from a day-time interval value at a given
+ position in a variant.
+ """
+ cls._check_index(pos, len(value))
+ basic_type, type_info = cls._get_type_info(value, pos)
+ if basic_type != VariantUtils.PRIMITIVE:
+ raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
+ if type_info == VariantUtils.DAY_TIME_INTERVAL:
+ micros = cls._read_long(value, pos + 2, 8, signed=True)
+ start_field, end_field = cls._get_day_time_interval_fields(value, pos + 1)
+ return (micros, start_field, end_field)
+ raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
+
 @classmethod
 def _get_string(cls, value: bytes, pos: int) -> str:
 cls._check_index(pos, len(value))
@@ -350,8 +423,139 @@ def _get_type(cls, value: bytes, pos: int) -> Any:
 return datetime.datetime
 elif type_info == VariantUtils.LONG_STR:
 return str
+ elif type_info == VariantUtils.DAY_TIME_INTERVAL:
+ return datetime.timedelta
+ elif type_info == VariantUtils.YEAR_MONTH_INTERVAL:
+ return cls._PlaceholderYearMonthIntervalInternalType
 raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})

+ @classmethod
+ def _to_year_month_interval_ansi_string(cls, months: int, start_field: int, end_field: int):
+ YEAR = 0
+ MONTH = 1
+ MONTHS_PER_YEAR = 12
+ sign = ""
+ abs_months = months
+ if months < 0:
+ sign = "-"
+ abs_months = -abs_months
+ year = sign + str(abs_months // MONTHS_PER_YEAR)
+ year_and_month = year + "-" + str(abs_months % MONTHS_PER_YEAR)
+ format_builder = ["INTERVAL '"]
+ if start_field == end_field:
+ if start_field == YEAR:
+ format_builder.append(year + "' YEAR")
+ else:
+ format_builder.append(str(months) + "' MONTH")
+ else:
+ format_builder.append(year_and_month + "' YEAR TO MONTH")
+ return "".join(format_builder)
+
+ @classmethod
+ def _to_day_time_interval_ansi_string(cls,
+ micros: int, start_field: int, end_field: int) -> str:
+ DAY = 0
+ HOUR = 1
+ MINUTE = 2
+ SECOND = 3
+ MIN_LONG_VALUE = -9223372036854775808
+ MAX_LONG_VALUE = 9223372036854775807
+ MICROS_PER_SECOND = 1000 * 1000
+ MICROS_PER_MINUTE = MICROS_PER_SECOND * 60
+ MICROS_PER_HOUR = MICROS_PER_MINUTE * 60
+ MICROS_PER_DAY = MICROS_PER_HOUR * 24
+ MAX_SECOND = MAX_LONG_VALUE // MICROS_PER_SECOND
+ MAX_MINUTE = MAX_LONG_VALUE // MICROS_PER_MINUTE
+ MAX_HOUR = MAX_LONG_VALUE // MICROS_PER_HOUR
+ MAX_DAY = MAX_LONG_VALUE // MICROS_PER_DAY
+
+ def field_to_string(field: int) -> str:
+ if field == DAY:
+ return "DAY"
+ elif field == HOUR:
+ return "HOUR"
+ elif field == MINUTE:
+ return "MINUTE"
+ elif field == SECOND:
+ return "SECOND"
+ raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
+
+ if end_field < start_field:
+ raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
+ sign = ""
+ rest = micros
+ from_str = field_to_string(start_field).upper()
+ to_str = field_to_string(end_field).upper()
+ prefix = "INTERVAL '"
+ postfix = f"' {from_str}" if (start_field == end_field) else f"' {from_str} TO {to_str}"
+ if micros < 0:
+ if micros == MIN_LONG_VALUE:
+ # Special handling of minimum `Long` value because negate op overflows `Long`.
+ # seconds = 106751991 * (24 * 60 * 60) + 4 * 60 * 60 + 54 = 9223372036854
+ # microseconds = -9223372036854000000L-775808 == Long.MinValue
+ base_str = "-106751991 04:00:54.775808000"
+ first_str = "-" + (
+ str(MAX_DAY)
+ if (start_field == DAY)
+ else (
+ str(MAX_HOUR)
+ if (start_field == HOUR)
+ else (
+ str(MAX_MINUTE)
+ if (start_field == MINUTE)
+ else str(MAX_SECOND) + ".775808"
+ )
+ )
+ )
+ if start_field == end_field:
+ return prefix + first_str + postfix
+ else:
+ substr_start = (
+ 10 if (start_field == DAY) else (13 if (start_field == HOUR) else 16)
+ )
+ substr_end = (
+ 13 if (end_field == HOUR) else (16 if (end_field == MINUTE) else 26)
+ )
+ return prefix + first_str + base_str[substr_start:substr_end] + postfix
+ else:
+ sign = "-"
+ rest = -rest
+ format_builder = [sign]
+ format_args = []
+ if start_field == DAY:
+ format_builder.append(str(rest // MICROS_PER_DAY))
+ rest %= MICROS_PER_DAY
+ elif start_field == HOUR:
+ format_builder.append("%02d")
+ format_args.append(rest // MICROS_PER_HOUR)
+ rest %= MICROS_PER_HOUR
+ elif start_field == MINUTE:
+ format_builder.append("%02d")
+ format_args.append(rest // MICROS_PER_MINUTE)
+ rest %= MICROS_PER_MINUTE
+ elif start_field == SECOND:
+ lead_zero = "0" if (rest < 10 * MICROS_PER_SECOND) else ""
+ format_builder.append(
+ lead_zero + (Decimal(rest) / Decimal(1000000)).normalize().to_eng_string()
+ )
+
+ if start_field < HOUR and HOUR <= end_field:
+ format_builder.append(" %02d")
+ format_args.append(rest // MICROS_PER_HOUR)
+ rest %= MICROS_PER_HOUR
+ if start_field < MINUTE and MINUTE <= end_field:
+ format_builder.append(":%02d")
+ format_args.append(rest // MICROS_PER_MINUTE)
+ rest %= MICROS_PER_MINUTE
+ if start_field < SECOND and SECOND <= end_field:
+ lead_zero = "0" if (rest < 10 * MICROS_PER_SECOND) else ""
+ format_builder.append(
+ ":"
+ + lead_zero
+ + (Decimal(rest) / Decimal(1000000)).normalize().to_eng_string()
+ )
+ return prefix + ("".join(format_builder) % tuple(format_args)) + postfix
+
 @classmethod
 def _to_json(cls, value: bytes, metadata: bytes, pos: int, zone_id: str) -> str:
 variant_type = cls._get_type(value, pos)
@@ -375,6 +579,13 @@ def handle_array(value_pos_list: List[int]) -> str:
 return "[" + ",".join(value_list) + "]"
 return cls._handle_array(value, pos, handle_array)
+ elif variant_type == datetime.timedelta:
+ micros, start_field, end_field = cls._get_timedelta(value, pos)
+ return '"' + cls._to_day_time_interval_ansi_string(micros, start_field, end_field) + 
'"' + elif variant_type == cls._PlaceholderYearMonthIntervalInternalType: + months, start_field, end_field = cls._get_yminterval_info(value, pos) + return '"' + cls._to_year_month_interval_ansi_string(months, start_field, + end_field) + '"' else: value = cls._get_scalar(variant_type, value, metadata, pos, zone_id) if value is None: @@ -412,6 +623,16 @@ def handle_array(value_pos_list: List[int]) -> List[Any]: return value_list return cls._handle_array(value, pos, handle_array) + elif variant_type == datetime.timedelta: + # day-time intervals are represented using timedelta in a trivial manner + return datetime.timedelta( + microseconds=cls._get_timedelta(value, pos)[0] + ) + elif variant_type == cls._PlaceholderYearMonthIntervalInternalType: + raise PySparkNotImplementedError( + errorClass="NOT_IMPLEMENTED", + messageParameters={"feature": "VariantUtils.YEAR_MONTH_INTERVAL"}, + ) else: return cls._get_scalar(variant_type, value, metadata, pos, zone_id="UTC") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala index d31af81424818..7682e4a048ecb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala @@ -104,6 +104,10 @@ object VariantExpressionEvalUtils { case DateType => builder.appendDate(input.asInstanceOf[Int]) case TimestampType => builder.appendTimestamp(input.asInstanceOf[Long]) case TimestampNTZType => builder.appendTimestampNtz(input.asInstanceOf[Long]) + case ymi: YearMonthIntervalType => + builder.appendYearMonthInterval(input.asInstanceOf[Int], ymi.startField, ymi.endField) + case dti: DayTimeIntervalType => + builder.appendDayTimeInterval(input.asInstanceOf[Long], dti.startField, dti.endField) case VariantType => val v = input.asInstanceOf[VariantVal] builder.appendVariant(new Variant(v.getValue, v.getMetadata)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala index b80fb11b6813b..de7544381381a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.StringTypeAnyCollation import org.apache.spark.sql.types._ import org.apache.spark.types.variant._ -import org.apache.spark.types.variant.VariantUtil.Type +import org.apache.spark.types.variant.VariantUtil.{IntervalFields, Type} import org.apache.spark.unsafe.types._ @@ -259,7 +259,7 @@ case object VariantGet { */ def checkDataType(dataType: DataType): Boolean = dataType match { case _: NumericType | BooleanType | _: StringType | BinaryType | _: DatetimeType | - VariantType => + VariantType | _: DayTimeIntervalType | _: YearMonthIntervalType => true case ArrayType(elementType, _) => checkDataType(elementType) case MapType(_: StringType, valueType, _) => checkDataType(valueType) @@ -353,9 +353,18 @@ case object VariantGet { case Type.TIMESTAMP_NTZ => Literal(v.getLong, TimestampNTZType) case Type.FLOAT => Literal(v.getFloat, 
FloatType)
 case Type.BINARY => Literal(v.getBinary, BinaryType)
+ case Type.YEAR_MONTH_INTERVAL =>
+ val fields: IntervalFields = v.getYearMonthIntervalFields
+ Literal(v.getLong.toInt, YearMonthIntervalType(fields.startField, fields.endField))
+ case Type.DAY_TIME_INTERVAL =>
+ val fields: IntervalFields = v.getDayTimeIntervalFields
+ Literal(v.getLong, DayTimeIntervalType(fields.startField, fields.endField))
- // We have handled other cases and should never reach here. This case is only intended
- // to by pass the compiler exhaustiveness check.
- case _ => throw QueryExecutionErrors.unreachableError()
+ // Type ids that are not handled above are unknown to this implementation and may be
+ // valid in a newer variant version.
+ case _ => throw new SparkRuntimeException(
+ errorClass = "UNKNOWN_PRIMITIVE_TYPE_IN_VARIANT",
+ messageParameters = Map("id" -> v.getTypeInfo.toString)
+ )
 }
 // We mostly use the `Cast` expression to implement the cast. However, `Cast` silently
 // ignores the overflow in the long/decimal -> timestamp cast, and we want to enforce
@@ -695,6 +704,12 @@ object SchemaOfVariant {
 case Type.TIMESTAMP_NTZ => TimestampNTZType
 case Type.FLOAT => FloatType
 case Type.BINARY => BinaryType
+ case Type.YEAR_MONTH_INTERVAL =>
+ val fields: IntervalFields = v.getYearMonthIntervalFields
+ YearMonthIntervalType(fields.startField, fields.endField)
+ case Type.DAY_TIME_INTERVAL =>
+ val fields: IntervalFields = v.getDayTimeIntervalFields
+ DayTimeIntervalType(fields.startField, fields.endField)
 }

 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
index a758fa84f6fca..fb0bf63c01123 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.catalyst.expressions.variant

-import java.time.{LocalDateTime, ZoneId, ZoneOffset}
+import java.time.{Duration, LocalDateTime, Period, ZoneId, ZoneOffset}

 import scala.collection.mutable
 import scala.reflect.runtime.universe.TypeTag
@@ -334,6 +334,96 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
 testInvalidVariantGet("9223372036855", "$", TimestampType)
 testInvalidVariantGet("0", "$", TimestampNTZType)

+ // year-month interval corner + random cases
+ Seq(0, 2147483647, -2147483648, 4398201, -213494932).foreach(input => {
+ for (startField <- YearMonthIntervalType.YEAR to YearMonthIntervalType.MONTH) {
+ for (endField <- startField to YearMonthIntervalType.MONTH) {
+ val cleanInput = if (endField == 0) input / 12 else input
+ // numeric source
+ testVariantGet(
+ cleanInput.toString,
+ "$",
+ YearMonthIntervalType(startField.toByte, endField.toByte),
+ Cast(
+ Literal(cleanInput, IntegerType),
+ YearMonthIntervalType(startField.toByte, endField.toByte)
+ ).eval()
+ )
+ // string source
+ testVariantGet(
+ "\"" + Cast(Cast(
+ Literal(cleanInput, IntegerType),
+ YearMonthIntervalType(startField.toByte, endField.toByte)
+ ), StringType).eval().toString + "\"",
+ "$",
+ YearMonthIntervalType(startField.toByte, endField.toByte),
+ Cast(
+ Literal(cleanInput, IntegerType),
+ YearMonthIntervalType(startField.toByte, endField.toByte)
+ ).eval()
+ )
+ }
+ }
+ })
+
+ // When a variant is cast to an interval type, none of the newly written code is used (the
+ // cast to interval was simply enabled); therefore, not all of the corner cases need to be
+ // tested.
+
+ // day-time interval corner cases. In the string-source examples, the expected value is
+ // obtained by casting to interval, then to string, then back to interval, since the cast
+ // from string to interval loses information.
+ testVariantGet("9223372036854.775807", "$", DayTimeIntervalType(0, 3),
+ Cast(Literal(Decimal("9223372036854.775807")), DayTimeIntervalType(0, 3)).eval()
+ )
+ testVariantGet(
+ "\"" + Cast(Cast(Literal(Decimal("9223372036854.775807")), DayTimeIntervalType(0, 3)),
+ StringType).eval().toString + "\"", "$", DayTimeIntervalType(0, 3),
+ Cast(
+ Cast(
+ Cast(Literal(Decimal("9223372036854.775807")), DayTimeIntervalType(0, 3)),
+ StringType
+ ),
+ DayTimeIntervalType(0, 3)
+ ).eval()
+ )
+ testVariantGet("-153722867280.912930", "$", DayTimeIntervalType(1, 2),
+ Cast(Literal(Decimal("-153722867280.912930")), DayTimeIntervalType(1, 2)).eval()
+ )
+ testVariantGet(
+ "\"" + Cast(Cast(Literal(Decimal("-153722867280.912930")), DayTimeIntervalType(1, 2)),
+ StringType).eval().toString + "\"", "$", DayTimeIntervalType(1, 2),
+ Cast(
+ Cast(
+ Cast(Literal(Decimal("-153722867280.912930")), DayTimeIntervalType(1, 2)),
+ StringType
+ ),
+ DayTimeIntervalType(1, 2)
+ ).eval()
+ )
+ testVariantGet("-2562047788.015215", "$", DayTimeIntervalType(0, 1),
+ Cast(Literal(Decimal("-2562047788.015215")), DayTimeIntervalType(0, 1)).eval()
+ )
+ testVariantGet(
+ "\"" + Cast(Cast(Literal(Decimal("-2562047788.015215")), DayTimeIntervalType(0, 1)),
+ StringType).eval().toString + "\"", "$", DayTimeIntervalType(0, 1),
+ Cast(
+ Cast(Cast(Literal(Decimal("-2562047788.015215")), DayTimeIntervalType(0, 1)), StringType),
+ DayTimeIntervalType(0, 1)
+ ).eval()
+ )
+ testVariantGet("-106751991.167300", "$", DayTimeIntervalType(0, 0),
+ Cast(Literal(Decimal("-106751991.167300")), DayTimeIntervalType(0, 0)).eval()
+ )
+ testVariantGet(
+ "\"" + Cast(Cast(Literal(Decimal("-106751991.167300")), DayTimeIntervalType(0, 0)),
+ StringType).eval().toString + "\"", "$", DayTimeIntervalType(0, 0),
+ Cast(
+ Cast(Cast(Literal(Decimal("-106751991.167300")), DayTimeIntervalType(0, 0)), StringType),
+ DayTimeIntervalType(0, 0)
+ ).eval()
+ )
+
 // Source type is double. Always use scientific notation to avoid decimal.
testVariantGet("1E0", "$", BooleanType, true) testVariantGet("0E0", "$", BooleanType, false) @@ -710,6 +800,14 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StructsToJson(Map.empty, input), expected) } + def checkToJsonFail(value: Array[Byte], id: Int): Unit = { + val input = Literal(new VariantVal(value, emptyMetadata)) + checkErrorInExpression[SparkRuntimeException]( + ResolveTimeZone.resolveTimeZones(StructsToJson(Map.empty, input)), + "UNKNOWN_PRIMITIVE_TYPE_IN_VARIANT", Map("id" -> id.toString) + ) + } + def checkCast(value: Array[Byte], dataType: DataType, expected: Any): Unit = { val input = Literal(new VariantVal(value, emptyMetadata)) checkEvaluation(Cast(input, dataType, evalMode = EvalMode.ANSI), expected) @@ -727,6 +825,37 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkCast(Array(primitiveHeader(DATE), 1, 0, 0, 0), TimestampType, MICROS_PER_DAY + 8 * MICROS_PER_HOUR) } + // corner + random cases + Seq(0, 2147483647, -2147483648, 345344843, -4357342).foreach(input => { + for (startField <- YearMonthIntervalType.YEAR to YearMonthIntervalType.MONTH) { + for (endField <- startField to YearMonthIntervalType.MONTH) { + val headerByte = startField | (endField << 1) + checkToJson(Array(primitiveHeader(YEAR_MONTH_INTERVAL), headerByte.toByte, + (input & 0xFF).toByte, ((input >> 8) & 0xFF).toByte, + ((input >> 16) & 0xFF).toByte, ((input >> 24) & 0xFF).toByte), + "\"" + Literal(input, YearMonthIntervalType(startField.toByte, endField.toByte)) + + "\"") + } + } + }) + + // corner + random cases + Seq(0L, 9223372036854775807L, -9223372036854775808L, 2374234381L, -23467681L).foreach(input => { + for (startField <- DayTimeIntervalType.DAY to DayTimeIntervalType.SECOND) { + for (endField <- startField to DayTimeIntervalType.SECOND) { + val headerByte = startField | (endField << 2) + checkToJson(Array(primitiveHeader(DAY_TIME_INTERVAL), headerByte.toByte, + (input & 0xFF).toByte, ((input >> 8) & 0xFF).toByte, + ((input >> 16) & 0xFF).toByte, ((input >> 24) & 0xFF).toByte, + ((input >> 32) & 0xFF).toByte, ((input >> 40) & 0xFF).toByte, + ((input >> 48) & 0xFF).toByte, ((input >> 56) & 0xFF).toByte), + "\"" + Literal(input, DayTimeIntervalType(startField.toByte, endField.toByte)) + + "\"") + } + } + }) + + checkToJsonFail(Array(primitiveHeader(25)), 25) def littleEndianLong(value: Long): Array[Byte] = BigInt(value).toByteArray.reverse.padTo(8, 0.toByte) @@ -839,6 +968,51 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { check("x" * 128, "\"" + ("x" * 128) + "\"") check(Array[Byte](1, 2, 3), "\"AQID\"") check(Literal(0, DateType), "\"1970-01-01\"") + + // year-month interval corner + random cases + Seq(0, 2147483647, -2147483648, 753992, -5920283).foreach(input => { + for (startField <- YearMonthIntervalType.YEAR to YearMonthIntervalType.MONTH) { + for (endField <- startField to YearMonthIntervalType.MONTH) { + val lit = Literal(input, YearMonthIntervalType(startField.toByte, endField.toByte)) + check(lit, "\"" + lit.toString + "\"") + } + } + }) + // Size of YMInterval + assert(Cast(Literal.create(Period.ofMonths(0)), VariantType, evalMode = EvalMode.ANSI) + .eval().asInstanceOf[VariantVal].getValue.length == 6) + + // Array of year-month intervals + val ymArrLit = Literal.create( + Array(Period.ofMonths(0), Period.ofMonths(2147483647), Period.ofMonths(-2147483647)), + ArrayType(YearMonthIntervalType(1, 1)) + ) + check(ymArrLit, """["INTERVAL '0' MONTH","INTERVAL""" + + """ 
'2147483647' MONTH","INTERVAL '-2147483647' MONTH"]""") + + // day-time interval corner + random cases + Seq(0L, 9223372036854775807L, -9223372036854775808L, 47356878948217L, -23745867989934789L) + .foreach(input => { + for (startField <- DayTimeIntervalType.DAY to DayTimeIntervalType.SECOND) { + for (endField <- startField to DayTimeIntervalType.SECOND) { + val lit = Literal(input, DayTimeIntervalType(startField.toByte, endField.toByte)) + check(lit, "\"" + lit.toString + "\"") + } + } + }) + // Size of DTInterval + assert(Cast(Literal.create(Duration.ofSeconds(0)), VariantType, evalMode = EvalMode.ANSI) + .eval().asInstanceOf[VariantVal].getValue.length == 10) + + // Array of day-time intervals + val dtArrLit = Literal.create( + Array(Duration.ofSeconds(0), Duration.ofSeconds(9223372036854L), + Duration.ofSeconds(-9223372036854L)), + ArrayType(DayTimeIntervalType(3, 3)) + ) + check(dtArrLit, """["INTERVAL '00' SECOND","INTERVAL""" + + """ '9223372036854' SECOND","INTERVAL '-9223372036854' SECOND"]""") + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { check(Literal(0L, TimestampType), "\"1970-01-01 00:00:00+00:00\"") check(Literal(0L, TimestampNTZType), "\"1970-01-01 00:00:00\"") @@ -862,6 +1036,34 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { check(struct, """{"a":{"i":0},"b":{"a":"123","b":"true","c":"f"},"c":["123","true","f"]}""") } + test("schema_of_variant - unknown type") { + val emptyMetadata = Array[Byte](VERSION, 0, 0) + + def checkErrorInSchemaOf(value: Array[Byte], id: Int): Unit = { + val input = Literal(new VariantVal(value, emptyMetadata)) + checkErrorInExpression[SparkRuntimeException]( + ResolveTimeZone.resolveTimeZones(SchemaOfVariant(input).replacement), + "UNKNOWN_PRIMITIVE_TYPE_IN_VARIANT", Map("id" -> id.toString) + ) + } + checkErrorInSchemaOf(Array(primitiveHeader(25)), 25) + } + + test("malformed interval type") { + val emptyMetadata = Array[Byte](VERSION, 0, 0) + + def checkErrorInIntervalVariant(value: Array[Byte], id: Int): Unit = { + val input = Literal(new VariantVal(value, emptyMetadata)) + checkErrorInExpression[SparkRuntimeException]( + ResolveTimeZone.resolveTimeZones(StructsToJson(Map.empty, input)), + "MALFORMED_VARIANT") + } + checkErrorInIntervalVariant(Array(primitiveHeader(YEAR_MONTH_INTERVAL), 0, 0, 0, 0), + YEAR_MONTH_INTERVAL) + checkErrorInIntervalVariant(Array(primitiveHeader(DAY_TIME_INTERVAL), 0, 0, 0, 0, 0, 0, 0, 0), + DAY_TIME_INTERVAL) + } + test("schema_of_variant - schema merge") { val nul = Literal(null, StringType) val boolean = Literal.default(BooleanType) @@ -878,8 +1080,15 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { val array2 = Literal(Array(0.0)) val struct1 = Literal.default(StructType.fromDDL("a string")) val struct2 = Literal.default(StructType.fromDDL("a boolean, b bigint")) + // TypeCoercion.findTightestCommonType handles interval types in expected ways. 
It doesn't make
+ // sense to merge intervals with other types
+ val dtInterval1 = Literal(0L, DayTimeIntervalType(1, 3))
+ val dtInterval2 = Literal(0L, DayTimeIntervalType(0, 2))
+ val ymInterval1 = Literal(0, YearMonthIntervalType(0, 0))
+ val ymInterval2 = Literal(0, YearMonthIntervalType(1, 1))
val inputs = Seq(nul, boolean, long, string, double, date, timestamp, timestampNtz, float,
- binary, decimal, array1, array2, struct1, struct2)
+ binary, decimal, array1, array2, struct1, struct2, dtInterval1, dtInterval2, ymInterval1,
+ ymInterval2)

val results = mutable.HashMap.empty[(Literal, Literal), String]
for (i <- inputs) {
@@ -898,6 +1107,8 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
results.put((float, decimal), "DOUBLE")
results.put((array1, array2), "ARRAY")
results.put((struct1, struct2), "STRUCT")
+ results.put((dtInterval1, dtInterval2), "INTERVAL DAY TO SECOND")
+ results.put((ymInterval1, ymInterval2), "INTERVAL YEAR TO MONTH")

for (i1 <- inputs) {
for (i2 <- inputs) {
diff --git a/streaming/pom.xml b/streaming/pom.xml
index 85a4d268d2a25..4deee019593b8 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -7,7 +7,7 @@
~ (the "License"); you may not use this file except in compliance with
~ the License. You may obtain a copy of the License at
~
- ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~ http://www.apache.g/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,

From be0bfc2aff6ddef7ba6d6252fd9c0dc72e222c70 Mon Sep 17 00:00:00 2001
From: Harsh Motwani
Date: Wed, 24 Jul 2024 13:02:48 -0700
Subject: [PATCH 02/17] minor fix

---
streaming/pom.xml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/streaming/pom.xml b/streaming/pom.xml
index 4deee019593b8..85a4d268d2a25 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -7,7 +7,7 @@
~ (the "License"); you may not use this file except in compliance with
~ the License.
You may obtain a copy of the License at ~ - ~ http://www.apache.g/licenses/LICENSE-2.0 + ~ http://www.apache.org/licenses/LICENSE-2.0 ~ ~ Unless required by applicable law or agreed to in writing, software ~ distributed under the License is distributed on an "AS IS" BASIS, From 2e0fc0701914d291a9353758d87bcf9d80b673fe Mon Sep 17 00:00:00 2001 From: Harsh Motwani Date: Wed, 24 Jul 2024 13:37:49 -0700 Subject: [PATCH 03/17] fix --- .../main/java/org/apache/spark/types/variant/Variant.java | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java index ecb25c51aa0d1..882ae538599ad 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java @@ -343,12 +343,8 @@ static void toJsonImpl(byte[] value, byte[] metadata, int pos, StringBuilder sb, case DAY_TIME_INTERVAL: IntervalFields dtFields = VariantUtil.getDayTimeIntervalFields(value, pos); long dtValue = VariantUtil.getLong(value, pos); - try { - appendQuoted(sb, DayTimeIntervalUtils.toDayTimeIntervalANSIString(dtValue, - dtFields.startField, dtFields.endField)); - } catch(Exception e) { - throw malformedVariant(); - } + appendQuoted(sb, DayTimeIntervalUtils.toDayTimeIntervalANSIString(dtValue, + dtFields.startField, dtFields.endField)); break; } } From f25b2f65e1effd0b794439f1d349504cf7fc474e Mon Sep 17 00:00:00 2001 From: Harsh Motwani Date: Wed, 24 Jul 2024 13:39:51 -0700 Subject: [PATCH 04/17] fix --- .../main/java/org/apache/spark/types/variant/VariantUtil.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java index 3177c81223186..74a4afee5e4be 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java @@ -417,8 +417,8 @@ public IntervalFields(byte startField, byte endField) { this.endField = endField; } - public byte startField; - public byte endField; + public final byte startField; + public final byte endField; } // Get the start and end fields of a variant value representing a year-month interval value. The From b7fe993b7c9ccb6cfc05063f60768ede6bf5141d Mon Sep 17 00:00:00 2001 From: Harsh Motwani Date: Wed, 24 Jul 2024 16:16:23 -0700 Subject: [PATCH 05/17] fixed Gene's suggestions --- .../resources/error/error-conditions.json | 2 +- .../spark/util/DayTimeIntervalUtils.java | 20 ++++++++ .../spark/util/YearMonthIntervalUtils.java | 17 +++++++ common/variant/README.md | 46 +++++++++---------- .../apache/spark/types/variant/Variant.java | 8 +++- .../spark/types/variant/VariantBuilder.java | 4 +- .../spark/types/variant/VariantUtil.java | 10 ++-- 7 files changed, 76 insertions(+), 31 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index a6f22de5eebe5..36b9fccaba404 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -4188,7 +4188,7 @@ }, "UNKNOWN_PRIMITIVE_TYPE_IN_VARIANT" : { "message" : [ - "Unknown primitive type with id was found in a variant value. The type might be supported in a newer version." 
+ "Unknown primitive type with id was found in a variant value." ], "sqlState" : "22023" }, diff --git a/common/utils/src/main/scala/org/apache/spark/util/DayTimeIntervalUtils.java b/common/utils/src/main/scala/org/apache/spark/util/DayTimeIntervalUtils.java index 83c4d0375b6c4..554eb8a154a46 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/DayTimeIntervalUtils.java +++ b/common/utils/src/main/scala/org/apache/spark/util/DayTimeIntervalUtils.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.util; import org.apache.spark.SparkException; @@ -41,6 +58,9 @@ public static String fieldToString(byte field) throws SparkException { } } + // Used to convert microseconds representing a day-time interval with a given start and end field + // to its ANSI SQL string representation. Throws a SparkException if startField or endField are + // out of bounds. public static String toDayTimeIntervalANSIString(long micros, byte startField, byte endField) throws SparkException { String sign = ""; diff --git a/common/utils/src/main/scala/org/apache/spark/util/YearMonthIntervalUtils.java b/common/utils/src/main/scala/org/apache/spark/util/YearMonthIntervalUtils.java index 84d8bd25090c1..9d867743c018e 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/YearMonthIntervalUtils.java +++ b/common/utils/src/main/scala/org/apache/spark/util/YearMonthIntervalUtils.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.util; // Replicating code from SparkIntervalUtils so code in the 'common' space can work with diff --git a/common/variant/README.md b/common/variant/README.md index 12c14361941cd..3b62ab1bac10f 100644 --- a/common/variant/README.md +++ b/common/variant/README.md @@ -335,29 +335,29 @@ The Decimal type contains a scale, but no precision. 
The implied precision of a | Object | `2` | A collection of (string-key, variant-value) pairs | | Array | `3` | An ordered sequence of variant values | -| Primitive Type | Type ID | Equivalent Parquet Type | Binary format | -|-----------------------------|---------|-----------------------------------------------|---------------------------------------------------------------------------------------------------------------------| -| null | `0` | any | none | -| boolean (True) | `1` | BOOLEAN | none | -| boolean (False) | `2` | BOOLEAN | none | -| int8 | `3` | INT(8, signed) | 1 byte | -| int16 | `4` | INT(16, signed) | 2 byte little-endian | -| int32 | `5` | INT(32, signed) | 4 byte little-endian | -| int64 | `6` | INT(64, signed) | 8 byte little-endian | -| double | `7` | DOUBLE | IEEE little-endian | -| decimal4 | `8` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | -| decimal8 | `9` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | -| decimal16 | `10` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | -| date | `11` | DATE | 4 byte little-endian | -| timestamp | `12` | TIMESTAMP(true, MICROS) | 8-byte little-endian | -| timestamp without time zone | `13` | TIMESTAMP(false, MICROS) | 8-byte little-endian | -| float | `14` | FLOAT | IEEE little-endian | -| binary | `15` | BINARY | 4 byte little-endian size, followed by bytes | -| string | `16` | STRING | 4 byte little-endian size, followed by UTF-8 encoded bytes | -| binary from metadata | `17` | BINARY | Little-endian index into the metadata dictionary. Number of bytes is equal to the metadata `offset_size`. | -| string from metadata | `18` | STRING | Little-endian index into the metadata dictionary. Number of bytes is equal to the metadata `offset_size`. | -| year-month interval | `19` | YearMonthIntervalType(start_field, end_field) | 1 byte denoting start field (1 bit) and end field (1 bit) starting at LSB followed by 4-byte little-endian value. | -| day-time interval | `20` | DayTimeIntervalType(start_field, end_field) | 1 byte denoting start field (2 bits) and end field (2 bits) starting at LSB followed by 8-byte little-endian value. 
| +| Primitive Type | Type ID | Equivalent Parquet Type | Binary format | +|-----------------------------|---------|---------------------------|---------------------------------------------------------------------------------------------------------------------| +| null | `0` | any | none | +| boolean (True) | `1` | BOOLEAN | none | +| boolean (False) | `2` | BOOLEAN | none | +| int8 | `3` | INT(8, signed) | 1 byte | +| int16 | `4` | INT(16, signed) | 2 byte little-endian | +| int32 | `5` | INT(32, signed) | 4 byte little-endian | +| int64 | `6` | INT(64, signed) | 8 byte little-endian | +| double | `7` | DOUBLE | IEEE little-endian | +| decimal4 | `8` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | +| decimal8 | `9` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | +| decimal16 | `10` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | +| date | `11` | DATE | 4 byte little-endian | +| timestamp | `12` | TIMESTAMP(true, MICROS) | 8-byte little-endian | +| timestamp without time zone | `13` | TIMESTAMP(false, MICROS) | 8-byte little-endian | +| float | `14` | FLOAT | IEEE little-endian | +| binary | `15` | BINARY | 4 byte little-endian size, followed by bytes | +| string | `16` | STRING | 4 byte little-endian size, followed by UTF-8 encoded bytes | +| binary from metadata | `17` | BINARY | Little-endian index into the metadata dictionary. Number of bytes is equal to the metadata `offset_size`. | +| string from metadata | `18` | STRING | Little-endian index into the metadata dictionary. Number of bytes is equal to the metadata `offset_size`. | +| year-month interval | `19` | | 1 byte denoting start field (1 bit) and end field (1 bit) starting at LSB followed by 4-byte little-endian value. | +| day-time interval | `20` | | 1 byte denoting start field (2 bits) and end field (2 bits) starting at LSB followed by 8-byte little-endian value. 
| | Decimal Precision | Decimal value type | |-----------------------|--------------------| diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java index 882ae538599ad..ecb25c51aa0d1 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java @@ -343,8 +343,12 @@ static void toJsonImpl(byte[] value, byte[] metadata, int pos, StringBuilder sb, case DAY_TIME_INTERVAL: IntervalFields dtFields = VariantUtil.getDayTimeIntervalFields(value, pos); long dtValue = VariantUtil.getLong(value, pos); - appendQuoted(sb, DayTimeIntervalUtils.toDayTimeIntervalANSIString(dtValue, - dtFields.startField, dtFields.endField)); + try { + appendQuoted(sb, DayTimeIntervalUtils.toDayTimeIntervalANSIString(dtValue, + dtFields.startField, dtFields.endField)); + } catch(Exception e) { + throw malformedVariant(); + } break; } } diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java index f603b787da12d..f5e5f729459f7 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java @@ -217,7 +217,7 @@ public void appendTimestampNtz(long microsSinceEpoch) { public void appendYearMonthInterval(long value, byte startField, byte endField) { checkCapacity(1 + 5); writeBuffer[writePos++] = primitiveHeader(YEAR_MONTH_INTERVAL); - writeBuffer[writePos++] = (byte) (startField | (endField << 1)); + writeBuffer[writePos++] = (byte) ((startField & 0x1) | ((endField & 0x1) << 1)); writeLong(writeBuffer, writePos, value, 4); writePos += 4; } @@ -225,7 +225,7 @@ public void appendYearMonthInterval(long value, byte startField, byte endField) public void appendDayTimeInterval(long value, byte startField, byte endField) { checkCapacity(1 + 9); writeBuffer[writePos++] = primitiveHeader(DAY_TIME_INTERVAL); - writeBuffer[writePos++] = (byte) (startField | (endField << 2)); + writeBuffer[writePos++] = (byte) ((startField & 0x3) | ((endField & 0x3) << 2)); writeLong(writeBuffer, writePos, value, 8); writePos += 8; } diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java index 74a4afee5e4be..ac856fb57890d 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java @@ -380,9 +380,12 @@ public static boolean getBoolean(byte[] value, int pos) { // Get a long value from variant value `value[pos...]`. // It is only legal to call it if `getType` returns one of `Type.LONG/DATE/TIMESTAMP/ - // TIMESTAMP_NTZ`. If the type is `DATE`, the return value is guaranteed to fit into an int and - // represents the number of days from the Unix epoch. If the type is `TIMESTAMP/TIMESTAMP_NTZ`, - // the return value represents the number of microseconds from the Unix epoch. + // TIMESTAMP_NTZ/YEAR_MONTH_INTERVAL/DAY_TIME_INTERVAL`. If the type is `DATE`, the return value + // is guaranteed to fit into an int and represents the number of days from the Unix epoch. + // If the type is `TIMESTAMP/TIMESTAMP_NTZ`, the return value represents the number of + // microseconds from the Unix epoch. 
If the type is `YEAR_MONTH_INTERVAL`, the return value + // represents the number of months in the interval. If the type is `DAY_TIME_INTERVAL`, the + // return value represents the number of microseconds in the interval. // Throw `MALFORMED_VARIANT` if the variant is malformed. public static long getLong(byte[] value, int pos) { checkIndex(pos, value.length); @@ -411,6 +414,7 @@ public static long getLong(byte[] value, int pos) { } } + // Class used to pass around start and end fields of year-month and day-time interval values. public static class IntervalFields { public IntervalFields(byte startField, byte endField) { this.startField = startField; From 59d8572a27497b7b9913f1bb6a715c4464d03c51 Mon Sep 17 00:00:00 2001 From: Harsh Motwani Date: Wed, 24 Jul 2024 16:38:33 -0700 Subject: [PATCH 06/17] added parquet equivalent for interval types in the readme --- common/variant/README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/common/variant/README.md b/common/variant/README.md index 3b62ab1bac10f..5d81fd6071274 100644 --- a/common/variant/README.md +++ b/common/variant/README.md @@ -356,8 +356,8 @@ The Decimal type contains a scale, but no precision. The implied precision of a | string | `16` | STRING | 4 byte little-endian size, followed by UTF-8 encoded bytes | | binary from metadata | `17` | BINARY | Little-endian index into the metadata dictionary. Number of bytes is equal to the metadata `offset_size`. | | string from metadata | `18` | STRING | Little-endian index into the metadata dictionary. Number of bytes is equal to the metadata `offset_size`. | -| year-month interval | `19` | | 1 byte denoting start field (1 bit) and end field (1 bit) starting at LSB followed by 4-byte little-endian value. | -| day-time interval | `20` | | 1 byte denoting start field (2 bits) and end field (2 bits) starting at LSB followed by 8-byte little-endian value. | +| year-month interval | `19` | INT(32, signed)* | 1 byte denoting start field (1 bit) and end field (1 bit) starting at LSB followed by 4-byte little-endian value. | +| day-time interval | `20` | INT(64, signed)* | 1 byte denoting start field (2 bits) and end field (2 bits) starting at LSB followed by 8-byte little-endian value. | | Decimal Precision | Decimal value type | |-----------------------|--------------------| @@ -368,6 +368,8 @@ The Decimal type contains a scale, but no precision. The implied precision of a The year-month and day-time interval types have one byte at the beginning indicating the start and end fields. In the case of the year-month interval, the least significant bit denotes the start field and the next least significant bit denotes the end field. The remaining 6 bits are unused. A field value of 0 represents YEAR and 1 represents MONTH. In the case of the day-time interval, the least significant 2 bits denote the start field and the next least significant 2 bits denote the end field. The remaining 4 bits are unused. A field value of 0 represents DAY, 1 represents HOUR, 2 represents MINUTE, and 3 represents SECOND. +The parquet format does not have pure equivalents for the year-month and day-time interval types. Year-month intervals are usually represented using int32 values and the day-time intervals are usually represented using int64 values. However, these values don't include the start and end fields of these types. Therefore, Spark stores them in the column metadata. 
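As an editorial aside, the bit layout described above can be made concrete with a small sketch. This is illustration only, not part of the patch: the class and method names below are invented, but the masks and shifts mirror the ones `VariantBuilder.appendYearMonthInterval` and `appendDayTimeInterval` use earlier in this series.

```java
// Hypothetical helper, for illustration only: packs and unpacks the interval
// header byte exactly as the spec text above describes.
public class IntervalHeaderSketch {
  // Year-month: 1 bit per field starting at the LSB; 0 = YEAR, 1 = MONTH.
  static byte packYearMonth(byte startField, byte endField) {
    return (byte) ((startField & 0x1) | ((endField & 0x1) << 1));
  }

  // Day-time: 2 bits per field starting at the LSB;
  // 0 = DAY, 1 = HOUR, 2 = MINUTE, 3 = SECOND.
  static byte packDayTime(byte startField, byte endField) {
    return (byte) ((startField & 0x3) | ((endField & 0x3) << 2));
  }

  static byte[] unpackDayTime(byte header) {
    return new byte[] {(byte) (header & 0x3), (byte) ((header >> 2) & 0x3)};
  }

  public static void main(String[] args) {
    // INTERVAL DAY TO SECOND: start = 0 (DAY), end = 3 (SECOND) packs to 0b1100.
    byte header = packDayTime((byte) 0, (byte) 3);
    byte[] fields = unpackDayTime(header);
    System.out.println(header + " -> start=" + fields[0] + ", end=" + fields[1]);
  }
}
```

For instance, `INTERVAL HOUR TO MINUTE` packs start field 1 and end field 2 into `0b1001`, and the 4 (day-time) or 6 (year-month) unused high bits stay zero.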
+ # Field ID order and uniqueness For objects, field IDs and offsets must be listed in the order of the corresponding field names, sorted lexicographically. Note that the fields themselves are not required to follow this order. As a result, offsets will not necessarily be listed in ascending order. From f073b755499aa200546acc0f32564ef3892bf375 Mon Sep 17 00:00:00 2001 From: Harsh Motwani Date: Wed, 24 Jul 2024 16:46:09 -0700 Subject: [PATCH 07/17] minor change --- common/variant/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/variant/README.md b/common/variant/README.md index 5d81fd6071274..d18e958057059 100644 --- a/common/variant/README.md +++ b/common/variant/README.md @@ -368,7 +368,7 @@ The Decimal type contains a scale, but no precision. The implied precision of a The year-month and day-time interval types have one byte at the beginning indicating the start and end fields. In the case of the year-month interval, the least significant bit denotes the start field and the next least significant bit denotes the end field. The remaining 6 bits are unused. A field value of 0 represents YEAR and 1 represents MONTH. In the case of the day-time interval, the least significant 2 bits denote the start field and the next least significant 2 bits denote the end field. The remaining 4 bits are unused. A field value of 0 represents DAY, 1 represents HOUR, 2 represents MINUTE, and 3 represents SECOND. -The parquet format does not have pure equivalents for the year-month and day-time interval types. Year-month intervals are usually represented using int32 values and the day-time intervals are usually represented using int64 values. However, these values don't include the start and end fields of these types. Therefore, Spark stores them in the column metadata. +*The parquet format does not have pure equivalents for the year-month and day-time interval types. Year-month intervals are usually represented using int32 values and the day-time intervals are usually represented using int64 values. However, these values don't include the start and end fields of these types. Therefore, Spark stores them in the column metadata. # Field ID order and uniqueness From e727fb1db5f9d71adeab5a3cb1acdbf1e6939443 Mon Sep 17 00:00:00 2001 From: Harsh Motwani Date: Wed, 24 Jul 2024 16:47:59 -0700 Subject: [PATCH 08/17] minor change --- common/variant/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/common/variant/README.md b/common/variant/README.md index d18e958057059..50d645cdb5e7e 100644 --- a/common/variant/README.md +++ b/common/variant/README.md @@ -356,8 +356,8 @@ The Decimal type contains a scale, but no precision. The implied precision of a | string | `16` | STRING | 4 byte little-endian size, followed by UTF-8 encoded bytes | | binary from metadata | `17` | BINARY | Little-endian index into the metadata dictionary. Number of bytes is equal to the metadata `offset_size`. | | string from metadata | `18` | STRING | Little-endian index into the metadata dictionary. Number of bytes is equal to the metadata `offset_size`. | -| year-month interval | `19` | INT(32, signed)* | 1 byte denoting start field (1 bit) and end field (1 bit) starting at LSB followed by 4-byte little-endian value. | -| day-time interval | `20` | INT(64, signed)* | 1 byte denoting start field (2 bits) and end field (2 bits) starting at LSB followed by 8-byte little-endian value. 
| +| year-month interval | `19` | INT(32, signed) [*] | 1 byte denoting start field (1 bit) and end field (1 bit) starting at LSB followed by 4-byte little-endian value. | +| day-time interval | `20` | INT(64, signed) [*] | 1 byte denoting start field (2 bits) and end field (2 bits) starting at LSB followed by 8-byte little-endian value. | | Decimal Precision | Decimal value type | |-----------------------|--------------------| @@ -368,7 +368,7 @@ The Decimal type contains a scale, but no precision. The implied precision of a The year-month and day-time interval types have one byte at the beginning indicating the start and end fields. In the case of the year-month interval, the least significant bit denotes the start field and the next least significant bit denotes the end field. The remaining 6 bits are unused. A field value of 0 represents YEAR and 1 represents MONTH. In the case of the day-time interval, the least significant 2 bits denote the start field and the next least significant 2 bits denote the end field. The remaining 4 bits are unused. A field value of 0 represents DAY, 1 represents HOUR, 2 represents MINUTE, and 3 represents SECOND. -*The parquet format does not have pure equivalents for the year-month and day-time interval types. Year-month intervals are usually represented using int32 values and the day-time intervals are usually represented using int64 values. However, these values don't include the start and end fields of these types. Therefore, Spark stores them in the column metadata. +[*] The parquet format does not have pure equivalents for the year-month and day-time interval types. Year-month intervals are usually represented using int32 values and the day-time intervals are usually represented using int64 values. However, these values don't include the start and end fields of these types. Therefore, Spark stores them in the column metadata. # Field ID order and uniqueness From fa1b481583ae31877a42a3e173563ba58a37594a Mon Sep 17 00:00:00 2001 From: Harsh Motwani Date: Wed, 24 Jul 2024 17:11:53 -0700 Subject: [PATCH 09/17] minor change --- common/variant/README.md | 48 ++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/common/variant/README.md b/common/variant/README.md index 50d645cdb5e7e..391815dabf99f 100644 --- a/common/variant/README.md +++ b/common/variant/README.md @@ -335,29 +335,29 @@ The Decimal type contains a scale, but no precision. 
The implied precision of a | Object | `2` | A collection of (string-key, variant-value) pairs | | Array | `3` | An ordered sequence of variant values | -| Primitive Type | Type ID | Equivalent Parquet Type | Binary format | -|-----------------------------|---------|---------------------------|---------------------------------------------------------------------------------------------------------------------| -| null | `0` | any | none | -| boolean (True) | `1` | BOOLEAN | none | -| boolean (False) | `2` | BOOLEAN | none | -| int8 | `3` | INT(8, signed) | 1 byte | -| int16 | `4` | INT(16, signed) | 2 byte little-endian | -| int32 | `5` | INT(32, signed) | 4 byte little-endian | -| int64 | `6` | INT(64, signed) | 8 byte little-endian | -| double | `7` | DOUBLE | IEEE little-endian | -| decimal4 | `8` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | -| decimal8 | `9` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | -| decimal16 | `10` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | -| date | `11` | DATE | 4 byte little-endian | -| timestamp | `12` | TIMESTAMP(true, MICROS) | 8-byte little-endian | -| timestamp without time zone | `13` | TIMESTAMP(false, MICROS) | 8-byte little-endian | -| float | `14` | FLOAT | IEEE little-endian | -| binary | `15` | BINARY | 4 byte little-endian size, followed by bytes | -| string | `16` | STRING | 4 byte little-endian size, followed by UTF-8 encoded bytes | -| binary from metadata | `17` | BINARY | Little-endian index into the metadata dictionary. Number of bytes is equal to the metadata `offset_size`. | -| string from metadata | `18` | STRING | Little-endian index into the metadata dictionary. Number of bytes is equal to the metadata `offset_size`. | -| year-month interval | `19` | INT(32, signed) [*] | 1 byte denoting start field (1 bit) and end field (1 bit) starting at LSB followed by 4-byte little-endian value. | -| day-time interval | `20` | INT(64, signed) [*] | 1 byte denoting start field (2 bits) and end field (2 bits) starting at LSB followed by 8-byte little-endian value. 
| +| Primitive Type | Type ID | Equivalent Parquet Type | Binary format | +|-----------------------------|---------|-----------------------------|---------------------------------------------------------------------------------------------------------------------| +| null | `0` | any | none | +| boolean (True) | `1` | BOOLEAN | none | +| boolean (False) | `2` | BOOLEAN | none | +| int8 | `3` | INT(8, signed) | 1 byte | +| int16 | `4` | INT(16, signed) | 2 byte little-endian | +| int32 | `5` | INT(32, signed) | 4 byte little-endian | +| int64 | `6` | INT(64, signed) | 8 byte little-endian | +| double | `7` | DOUBLE | IEEE little-endian | +| decimal4 | `8` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | +| decimal8 | `9` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | +| decimal16 | `10` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | +| date | `11` | DATE | 4 byte little-endian | +| timestamp | `12` | TIMESTAMP(true, MICROS) | 8-byte little-endian | +| timestamp without time zone | `13` | TIMESTAMP(false, MICROS) | 8-byte little-endian | +| float | `14` | FLOAT | IEEE little-endian | +| binary | `15` | BINARY | 4 byte little-endian size, followed by bytes | +| string | `16` | STRING | 4 byte little-endian size, followed by UTF-8 encoded bytes | +| binary from metadata | `17` | BINARY | Little-endian index into the metadata dictionary. Number of bytes is equal to the metadata `offset_size`. | +| string from metadata | `18` | STRING | Little-endian index into the metadata dictionary. Number of bytes is equal to the metadata `offset_size`. | +| year-month interval | `19` | INT(32, signed)1 | 1 byte denoting start field (1 bit) and end field (1 bit) starting at LSB followed by 4-byte little-endian value. | +| day-time interval | `20` | INT(64, signed)1 | 1 byte denoting start field (2 bits) and end field (2 bits) starting at LSB followed by 8-byte little-endian value. | | Decimal Precision | Decimal value type | |-----------------------|--------------------| @@ -368,7 +368,7 @@ The Decimal type contains a scale, but no precision. The implied precision of a The year-month and day-time interval types have one byte at the beginning indicating the start and end fields. In the case of the year-month interval, the least significant bit denotes the start field and the next least significant bit denotes the end field. The remaining 6 bits are unused. A field value of 0 represents YEAR and 1 represents MONTH. In the case of the day-time interval, the least significant 2 bits denote the start field and the next least significant 2 bits denote the end field. The remaining 4 bits are unused. A field value of 0 represents DAY, 1 represents HOUR, 2 represents MINUTE, and 3 represents SECOND. -[*] The parquet format does not have pure equivalents for the year-month and day-time interval types. Year-month intervals are usually represented using int32 values and the day-time intervals are usually represented using int64 values. However, these values don't include the start and end fields of these types. Therefore, Spark stores them in the column metadata. +[1] The parquet format does not have pure equivalents for the year-month and day-time interval types. 
Year-month intervals are usually represented using int32 values and the day-time intervals are usually represented using int64 values. However, these values don't include the start and end fields of these types. Therefore, Spark stores them in the column metadata.

# Field ID order and uniqueness

From 2ce92736822f8ee7e03f7766620ac85165af0875 Mon Sep 17 00:00:00 2001
From: Harsh Motwani
Date: Wed, 24 Jul 2024 17:15:53 -0700
Subject: [PATCH 10/17] comment change

---
.../scala/org/apache/spark/util/DayTimeIntervalUtils.java | 2 +-
.../org/apache/spark/util/YearMonthIntervalUtils.java | 2 ++
python/pyspark/sql/variant_utils.py | 8 ++++++++
3 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/common/utils/src/main/scala/org/apache/spark/util/DayTimeIntervalUtils.java b/common/utils/src/main/scala/org/apache/spark/util/DayTimeIntervalUtils.java
index 554eb8a154a46..ce86ee1523936 100644
--- a/common/utils/src/main/scala/org/apache/spark/util/DayTimeIntervalUtils.java
+++ b/common/utils/src/main/scala/org/apache/spark/util/DayTimeIntervalUtils.java
@@ -58,7 +58,7 @@ public static String fieldToString(byte field) throws SparkException {
}
}

- // Used to convert microseconds representing a day-time interval with a given start and end field
+ // Used to convert microseconds representing a day-time interval with given start and end fields
// to its ANSI SQL string representation. Throws a SparkException if startField or endField are
// out of bounds.
public static String toDayTimeIntervalANSIString(long micros, byte startField, byte endField)
diff --git a/common/utils/src/main/scala/org/apache/spark/util/YearMonthIntervalUtils.java b/common/utils/src/main/scala/org/apache/spark/util/YearMonthIntervalUtils.java
index 9d867743c018e..e627c57ebba70 100644
--- a/common/utils/src/main/scala/org/apache/spark/util/YearMonthIntervalUtils.java
+++ b/common/utils/src/main/scala/org/apache/spark/util/YearMonthIntervalUtils.java
@@ -24,6 +24,8 @@ public class YearMonthIntervalUtils {
private static byte MONTH = 1;
private static int MONTHS_PER_YEAR = 12;

+ // Used to convert months representing a year-month interval with given start and end fields
+ // to its ANSI SQL string representation.
public static String toYearMonthIntervalANSIString(int months, byte startField, byte endField) {
String sign = "";
long absMonths = months;
diff --git a/python/pyspark/sql/variant_utils.py b/python/pyspark/sql/variant_utils.py
index 95df421c0d493..afe548dcaa420 100644
--- a/python/pyspark/sql/variant_utils.py
+++ b/python/pyspark/sql/variant_utils.py
@@ -431,6 +431,10 @@ def _get_type(cls, value: bytes, pos: int) -> Any:

@classmethod
def _to_year_month_interval_ansi_string(cls, months: int, start_field: int, end_field: int):
+ """
+ Used to convert months representing a year-month interval with given start and end
+ fields to its ANSI SQL string representation.
+ """
YEAR = 0
MONTH = 1
MONTHS_PER_YEAR = 12
@@ -454,6 +458,10 @@ def _to_year_month_interval_ansi_string(cls, months: int, start_field: int, end_

@classmethod
def _to_day_time_interval_ansi_string(cls,
micros: int, start_field: int, end_field: int) -> str:
+ """
+ Used to convert microseconds representing a day-time interval with given start and end
+ fields to its ANSI SQL string representation.
+ """ DAY = 0 HOUR = 1 MINUTE = 2 From 9ab3ccbe1db0336ec5e1952b937698077ddb747d Mon Sep 17 00:00:00 2001 From: Harsh Motwani Date: Wed, 24 Jul 2024 17:19:40 -0700 Subject: [PATCH 11/17] changed method name --- python/pyspark/sql/variant_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/variant_utils.py b/python/pyspark/sql/variant_utils.py index afe548dcaa420..cf185db3471ee 100644 --- a/python/pyspark/sql/variant_utils.py +++ b/python/pyspark/sql/variant_utils.py @@ -293,7 +293,7 @@ def _get_yminterval_info(cls, value: bytes, pos: int) -> Tuple[int, int, int]: raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) @classmethod - def _get_timedelta(cls, value: bytes, pos: int) -> Tuple[int, int, int]: + def _get_dtinterval_info(cls, value: bytes, pos: int) -> Tuple[int, int, int]: """ Returns the (micros, start_field, end_field) tuple from a day-time interval value at a given position in a variant. @@ -588,7 +588,7 @@ def handle_array(value_pos_list: List[int]) -> str: return cls._handle_array(value, pos, handle_array) elif variant_type == datetime.timedelta: - micros, start_field, end_field = cls._get_timedelta(value, pos) + micros, start_field, end_field = cls._get_dtinterval_info(value, pos) return '"' + cls._to_day_time_interval_ansi_string(micros, start_field, end_field) + '"' elif variant_type == cls._PlaceholderYearMonthIntervalInternalType: months, start_field, end_field = cls._get_yminterval_info(value, pos) @@ -634,7 +634,7 @@ def handle_array(value_pos_list: List[int]) -> List[Any]: elif variant_type == datetime.timedelta: # day-time intervals are represented using timedelta in a trivial manner return datetime.timedelta( - microseconds=cls._get_timedelta(value, pos)[0] + microseconds=cls._get_dtinterval_info(value, pos)[0] ) elif variant_type == cls._PlaceholderYearMonthIntervalInternalType: raise PySparkNotImplementedError( From 9aecb2cd1286eaca9166aba7137c5e8f560b8160 Mon Sep 17 00:00:00 2001 From: Harsh Motwani Date: Thu, 25 Jul 2024 10:19:59 -0700 Subject: [PATCH 12/17] reformat --- python/pyspark/sql/tests/test_types.py | 276 +++++++++++++------------ python/pyspark/sql/variant_utils.py | 20 +- 2 files changed, 149 insertions(+), 147 deletions(-) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 9482fc35671fc..8610ace52d86a 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -2127,97 +2127,97 @@ def test_variant_type(self): ).collect()[0] # Highest possible DT interval value high_dt_interval_columns = self.spark.sql( - "select 9223372036854.775807::interval day to second::variant as dti00, " + - "9223372036854.775807::interval hour to second::variant as dti01, " + - "9223372036854.775807::interval minute to second::variant as dti02, " + - "9223372036854.775807::interval second::variant as dti03, " + - "153722867280.912930::interval day to minute::variant as dti10, " + - "153722867280.912930::interval hour to minute::variant as dti11, " + - "153722867280.912930::interval minute::variant as dti12, " + - "2562047788.015215::interval day to hour::variant as dti20, " + - "2562047788.015215::interval hour::variant as dti21, " + - "106751991.167300::interval day::variant as dti30" + "select 9223372036854.775807::interval day to second::variant as dti00, " + + "9223372036854.775807::interval hour to second::variant as dti01, " + + "9223372036854.775807::interval minute to second::variant as dti02, " + + 
"9223372036854.775807::interval second::variant as dti03, " + + "153722867280.912930::interval day to minute::variant as dti10, " + + "153722867280.912930::interval hour to minute::variant as dti11, " + + "153722867280.912930::interval minute::variant as dti12, " + + "2562047788.015215::interval day to hour::variant as dti20, " + + "2562047788.015215::interval hour::variant as dti21, " + + "106751991.167300::interval day::variant as dti30" ).collect()[0] # Lowest possible DT interval value low_dt_interval_columns = self.spark.sql( - "select -9223372036854.775808::interval day to second::variant as dti00, " + - "-9223372036854.775808::interval hour to second::variant as dti01, " + - "-9223372036854.775808::interval minute to second::variant as dti02, " + - "-9223372036854.775808::interval second::variant as dti03, " + - "-153722867280.912930::interval day to minute::variant as dti10, " + - "-153722867280.912930::interval hour to minute::variant as dti11, " + - "-153722867280.912930::interval minute::variant as dti12, " + - "-2562047788.015215::interval day to hour::variant as dti20, " + - "-2562047788.015215::interval hour::variant as dti21, " + - "-106751991.167300::interval day::variant as dti30" + "select -9223372036854.775808::interval day to second::variant as dti00, " + + "-9223372036854.775808::interval hour to second::variant as dti01, " + + "-9223372036854.775808::interval minute to second::variant as dti02, " + + "-9223372036854.775808::interval second::variant as dti03, " + + "-153722867280.912930::interval day to minute::variant as dti10, " + + "-153722867280.912930::interval hour to minute::variant as dti11, " + + "-153722867280.912930::interval minute::variant as dti12, " + + "-2562047788.015215::interval day to hour::variant as dti20, " + + "-2562047788.015215::interval hour::variant as dti21, " + + "-106751991.167300::interval day::variant as dti30" ).collect()[0] zero_dt_interval_columns = self.spark.sql( - "select 0::interval day to second::variant as dti00, " + - "0::interval hour to second::variant as dti01, " + - "0::interval minute to second::variant as dti02, " + - "0::interval second::variant as dti03, " + - "0::interval day to minute::variant as dti10, " + - "0::interval hour to minute::variant as dti11, " + - "0::interval minute::variant as dti12, " + - "0::interval day to hour::variant as dti20, " + - "0::interval hour::variant as dti21, " + - "0::interval day::variant as dti30" + "select 0::interval day to second::variant as dti00, " + + "0::interval hour to second::variant as dti01, " + + "0::interval minute to second::variant as dti02, " + + "0::interval second::variant as dti03, " + + "0::interval day to minute::variant as dti10, " + + "0::interval hour to minute::variant as dti11, " + + "0::interval minute::variant as dti12, " + + "0::interval day to hour::variant as dti20, " + + "0::interval hour::variant as dti21, " + + "0::interval day::variant as dti30" ).collect()[0] # Random positive dt interval value rand_pos_dt_interval_columns = self.spark.sql( - "select 12893121435::interval day to second::variant as dti00, " + - "273457447832::interval hour to second::variant as dti01, " + - "234233247::interval minute to second::variant as dti02, " + - "9310354::interval second::variant as dti03, " + - "214885357::interval day to minute::variant as dti10, " + - "4557624130::interval hour to minute::variant as dti11, " + - "3903887::interval minute::variant as dti12, " + - "3581422::interval day to hour::variant as dti20, " + - "75960402::interval hour::variant as 
dti21, " + - "65064::interval day::variant as dti30" + "select 12893121435::interval day to second::variant as dti00, " + + "273457447832::interval hour to second::variant as dti01, " + + "234233247::interval minute to second::variant as dti02, " + + "9310354::interval second::variant as dti03, " + + "214885357::interval day to minute::variant as dti10, " + + "4557624130::interval hour to minute::variant as dti11, " + + "3903887::interval minute::variant as dti12, " + + "3581422::interval day to hour::variant as dti20, " + + "75960402::interval hour::variant as dti21, " + + "65064::interval day::variant as dti30" ).collect()[0] # Random negative dt interval value rand_neg_dt_interval_columns = self.spark.sql( - "select -426547473652::interval day to second::variant as dti00, " + - "-2327834334::interval hour to second::variant as dti01, " + - "-324223232::interval minute to second::variant as dti02, " + - "-2342332::interval second::variant as dti03, " + - "-7109124560::interval day to minute::variant as dti10, " + - "-38797238::interval hour to minute::variant as dti11, " + - "-5403720::interval minute::variant as dti12, " + - "-118485409::interval day to hour::variant as dti20, " + - "-646620::interval hour::variant as dti21, " + - "-90062::interval day::variant as dti30" + "select -426547473652::interval day to second::variant as dti00, " + + "-2327834334::interval hour to second::variant as dti01, " + + "-324223232::interval minute to second::variant as dti02, " + + "-2342332::interval second::variant as dti03, " + + "-7109124560::interval day to minute::variant as dti10, " + + "-38797238::interval hour to minute::variant as dti11, " + + "-5403720::interval minute::variant as dti12, " + + "-118485409::interval day to hour::variant as dti20, " + + "-646620::interval hour::variant as dti21, " + + "-90062::interval day::variant as dti30" ).collect()[0] # Highest possible ym interval value high_ym_interval_columns = self.spark.sql( - "select 2147483647::interval year to month::variant as ymi0, " + - "2147483647::interval month::variant as ymi1, " + - "178956970::interval year::variant as ymi2" + "select 2147483647::interval year to month::variant as ymi0, " + + "2147483647::interval month::variant as ymi1, " + + "178956970::interval year::variant as ymi2" ).collect()[0] # Lowest possible ym interval value low_ym_interval_columns = self.spark.sql( - "select -2147483648::interval year to month::variant as ymi0, " + - "-2147483648::interval month::variant as ymi1, " + - "-178956970::interval year::variant as ymi2" + "select -2147483648::interval year to month::variant as ymi0, " + + "-2147483648::interval month::variant as ymi1, " + + "-178956970::interval year::variant as ymi2" ).collect()[0] zero_ym_interval_columns = self.spark.sql( - "select 0::interval year to month::variant ymi0, " + - "0::interval month::variant ymi1, " + - "0::interval year::variant ymi2" + "select 0::interval year to month::variant ymi0, " + + "0::interval month::variant ymi1, " + + "0::interval year::variant ymi2" ).collect()[0] # Random positive ym interval value rand_pos_ym_interval_columns = self.spark.sql( - "select 24678537::interval year to month::variant ymi0, " + - "345763467::interval month::variant ymi1, " + - "45723888::interval year::variant ymi2" + "select 24678537::interval year to month::variant ymi0, " + + "345763467::interval month::variant ymi1, " + + "45723888::interval year::variant ymi2" ).collect()[0] # Random negative ym interval value rand_neg_ym_interval_columns = self.spark.sql( - "select 
-425245345::interval year to month::variant ymi0, " + - "-849348229::interval month::variant ymi1, " + - "-85349890::interval year::variant ymi2" + "select -425245345::interval year to month::variant ymi0, " + + "-849348229::interval month::variant ymi1, " + + "-85349890::interval year::variant ymi2" ).collect()[0] variants = [ @@ -2324,77 +2324,81 @@ def test_variant_type(self): self.assertEqual(str(variants[12]), '"1940-01-01 05:05:13.123000+00:00"') self.assertEqual(str(variants[13]), '"2522-12-31 05:23:00+00:00"') self.assertEqual(str(variants[14]), '"0001-12-30 17:01:01+00:00"') - self.assertEqual(str(variants[15]), - '"INTERVAL \'106751991 04:00:54.775807\' DAY TO SECOND"') - self.assertEqual(str(variants[16]), - '"INTERVAL \'2562047788:00:54.775807\' HOUR TO SECOND"') - self.assertEqual(str(variants[17]), - '"INTERVAL \'153722867280:54.775807\' MINUTE TO SECOND"') - self.assertEqual(str(variants[18]), '"INTERVAL \'9223372036854.775807\' SECOND"') - self.assertEqual(str(variants[19]), '"INTERVAL \'106751991 04:00\' DAY TO MINUTE"') - self.assertEqual(str(variants[20]), '"INTERVAL \'2562047788:00\' HOUR TO MINUTE"') - self.assertEqual(str(variants[21]), '"INTERVAL \'153722867280\' MINUTE"') - self.assertEqual(str(variants[22]), '"INTERVAL \'106751991 04\' DAY TO HOUR"') - self.assertEqual(str(variants[23]), '"INTERVAL \'2562047788\' HOUR"') - self.assertEqual(str(variants[24]), '"INTERVAL \'106751991\' DAY"') - self.assertEqual(str(variants[25]), - '"INTERVAL \'-106751991 04:00:54.775808\' DAY TO SECOND"') - self.assertEqual(str(variants[26]), - '"INTERVAL \'-2562047788:00:54.775808\' HOUR TO SECOND"') - self.assertEqual(str(variants[27]), - '"INTERVAL \'-153722867280:54.775808\' MINUTE TO SECOND"') - self.assertEqual(str(variants[28]), '"INTERVAL \'-9223372036854.775808\' SECOND"') - self.assertEqual(str(variants[29]), '"INTERVAL \'-106751991 04:00\' DAY TO MINUTE"') - self.assertEqual(str(variants[30]), '"INTERVAL \'-2562047788:00\' HOUR TO MINUTE"') - self.assertEqual(str(variants[31]), '"INTERVAL \'-153722867280\' MINUTE"') - self.assertEqual(str(variants[32]), '"INTERVAL \'-106751991 04\' DAY TO HOUR"') - self.assertEqual(str(variants[33]), '"INTERVAL \'-2562047788\' HOUR"') - self.assertEqual(str(variants[34]), '"INTERVAL \'-106751991\' DAY"') - self.assertEqual(str(variants[35]), '"INTERVAL \'0 00:00:00\' DAY TO SECOND"') - self.assertEqual(str(variants[36]), '"INTERVAL \'00:00:00\' HOUR TO SECOND"') - self.assertEqual(str(variants[37]), '"INTERVAL \'00:00\' MINUTE TO SECOND"') - self.assertEqual(str(variants[38]), '"INTERVAL \'00\' SECOND"') - self.assertEqual(str(variants[39]), '"INTERVAL \'0 00:00\' DAY TO MINUTE"') - self.assertEqual(str(variants[40]), '"INTERVAL \'00:00\' HOUR TO MINUTE"') - self.assertEqual(str(variants[41]), '"INTERVAL \'00\' MINUTE"') - self.assertEqual(str(variants[42]), '"INTERVAL \'0 00\' DAY TO HOUR"') - self.assertEqual(str(variants[43]), '"INTERVAL \'00\' HOUR"') - self.assertEqual(str(variants[44]), '"INTERVAL \'0\' DAY"') - self.assertEqual(str(variants[45]), '"INTERVAL \'149225 22:37:15\' DAY TO SECOND"') - self.assertEqual(str(variants[46]), '"INTERVAL \'75960402:10:32\' HOUR TO SECOND"') - self.assertEqual(str(variants[47]), '"INTERVAL \'3903887:27\' MINUTE TO SECOND"') - self.assertEqual(str(variants[48]), '"INTERVAL \'9310354\' SECOND"') - self.assertEqual(str(variants[49]), '"INTERVAL \'149225 22:37\' DAY TO MINUTE"') - self.assertEqual(str(variants[50]), '"INTERVAL \'75960402:10\' HOUR TO MINUTE"') - self.assertEqual(str(variants[51]), 
'"INTERVAL \'3903887\' MINUTE"') - self.assertEqual(str(variants[52]), '"INTERVAL \'149225 22\' DAY TO HOUR"') - self.assertEqual(str(variants[53]), '"INTERVAL \'75960402\' HOUR"') - self.assertEqual(str(variants[54]), '"INTERVAL \'65064\' DAY"') - self.assertEqual(str(variants[55]), '"INTERVAL \'-4936892 01:20:52\' DAY TO SECOND"') - self.assertEqual(str(variants[56]), '"INTERVAL \'-646620:38:54\' HOUR TO SECOND"') - self.assertEqual(str(variants[57]), '"INTERVAL \'-5403720:32\' MINUTE TO SECOND"') - self.assertEqual(str(variants[58]), '"INTERVAL \'-2342332\' SECOND"') - self.assertEqual(str(variants[59]), '"INTERVAL \'-4936892 01:20\' DAY TO MINUTE"') - self.assertEqual(str(variants[60]), '"INTERVAL \'-646620:38\' HOUR TO MINUTE"') - self.assertEqual(str(variants[61]), '"INTERVAL \'-5403720\' MINUTE"') - self.assertEqual(str(variants[62]), '"INTERVAL \'-4936892 01\' DAY TO HOUR"') - self.assertEqual(str(variants[63]), '"INTERVAL \'-646620\' HOUR"') - self.assertEqual(str(variants[64]), '"INTERVAL \'-90062\' DAY"') - self.assertEqual(str(variants[65]), '"INTERVAL \'178956970-7\' YEAR TO MONTH"') - self.assertEqual(str(variants[66]), '"INTERVAL \'2147483647\' MONTH"') - self.assertEqual(str(variants[67]), '"INTERVAL \'178956970\' YEAR"') - self.assertEqual(str(variants[68]), '"INTERVAL \'-178956970-8\' YEAR TO MONTH"') - self.assertEqual(str(variants[69]), '"INTERVAL \'-2147483648\' MONTH"') - self.assertEqual(str(variants[70]), '"INTERVAL \'-178956970\' YEAR"') - self.assertEqual(str(variants[71]), '"INTERVAL \'0-0\' YEAR TO MONTH"') - self.assertEqual(str(variants[72]), '"INTERVAL \'0\' MONTH"') - self.assertEqual(str(variants[73]), '"INTERVAL \'0\' YEAR"') - self.assertEqual(str(variants[74]), '"INTERVAL \'2056544-9\' YEAR TO MONTH"') - self.assertEqual(str(variants[75]), '"INTERVAL \'345763467\' MONTH"') - self.assertEqual(str(variants[76]), '"INTERVAL \'45723888\' YEAR"') - self.assertEqual(str(variants[77]), '"INTERVAL \'-35437112-1\' YEAR TO MONTH"') - self.assertEqual(str(variants[78]), '"INTERVAL \'-849348229\' MONTH"') - self.assertEqual(str(variants[79]), '"INTERVAL \'-85349890\' YEAR"') + self.assertEqual( + str(variants[15]), "\"INTERVAL '106751991 04:00:54.775807' DAY TO SECOND\"" + ) + self.assertEqual(str(variants[16]), "\"INTERVAL '2562047788:00:54.775807' HOUR TO SECOND\"") + self.assertEqual( + str(variants[17]), "\"INTERVAL '153722867280:54.775807' MINUTE TO SECOND\"" + ) + self.assertEqual(str(variants[18]), "\"INTERVAL '9223372036854.775807' SECOND\"") + self.assertEqual(str(variants[19]), "\"INTERVAL '106751991 04:00' DAY TO MINUTE\"") + self.assertEqual(str(variants[20]), "\"INTERVAL '2562047788:00' HOUR TO MINUTE\"") + self.assertEqual(str(variants[21]), "\"INTERVAL '153722867280' MINUTE\"") + self.assertEqual(str(variants[22]), "\"INTERVAL '106751991 04' DAY TO HOUR\"") + self.assertEqual(str(variants[23]), "\"INTERVAL '2562047788' HOUR\"") + self.assertEqual(str(variants[24]), "\"INTERVAL '106751991' DAY\"") + self.assertEqual( + str(variants[25]), "\"INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND\"" + ) + self.assertEqual( + str(variants[26]), "\"INTERVAL '-2562047788:00:54.775808' HOUR TO SECOND\"" + ) + self.assertEqual( + str(variants[27]), "\"INTERVAL '-153722867280:54.775808' MINUTE TO SECOND\"" + ) + self.assertEqual(str(variants[28]), "\"INTERVAL '-9223372036854.775808' SECOND\"") + self.assertEqual(str(variants[29]), "\"INTERVAL '-106751991 04:00' DAY TO MINUTE\"") + self.assertEqual(str(variants[30]), "\"INTERVAL '-2562047788:00' HOUR TO MINUTE\"") 
+        self.assertEqual(str(variants[31]), "\"INTERVAL '-153722867280' MINUTE\"")
+        self.assertEqual(str(variants[32]), "\"INTERVAL '-106751991 04' DAY TO HOUR\"")
+        self.assertEqual(str(variants[33]), "\"INTERVAL '-2562047788' HOUR\"")
+        self.assertEqual(str(variants[34]), "\"INTERVAL '-106751991' DAY\"")
+        self.assertEqual(str(variants[35]), "\"INTERVAL '0 00:00:00' DAY TO SECOND\"")
+        self.assertEqual(str(variants[36]), "\"INTERVAL '00:00:00' HOUR TO SECOND\"")
+        self.assertEqual(str(variants[37]), "\"INTERVAL '00:00' MINUTE TO SECOND\"")
+        self.assertEqual(str(variants[38]), "\"INTERVAL '00' SECOND\"")
+        self.assertEqual(str(variants[39]), "\"INTERVAL '0 00:00' DAY TO MINUTE\"")
+        self.assertEqual(str(variants[40]), "\"INTERVAL '00:00' HOUR TO MINUTE\"")
+        self.assertEqual(str(variants[41]), "\"INTERVAL '00' MINUTE\"")
+        self.assertEqual(str(variants[42]), "\"INTERVAL '0 00' DAY TO HOUR\"")
+        self.assertEqual(str(variants[43]), "\"INTERVAL '00' HOUR\"")
+        self.assertEqual(str(variants[44]), "\"INTERVAL '0' DAY\"")
+        self.assertEqual(str(variants[45]), "\"INTERVAL '149225 22:37:15' DAY TO SECOND\"")
+        self.assertEqual(str(variants[46]), "\"INTERVAL '75960402:10:32' HOUR TO SECOND\"")
+        self.assertEqual(str(variants[47]), "\"INTERVAL '3903887:27' MINUTE TO SECOND\"")
+        self.assertEqual(str(variants[48]), "\"INTERVAL '9310354' SECOND\"")
+        self.assertEqual(str(variants[49]), "\"INTERVAL '149225 22:37' DAY TO MINUTE\"")
+        self.assertEqual(str(variants[50]), "\"INTERVAL '75960402:10' HOUR TO MINUTE\"")
+        self.assertEqual(str(variants[51]), "\"INTERVAL '3903887' MINUTE\"")
+        self.assertEqual(str(variants[52]), "\"INTERVAL '149225 22' DAY TO HOUR\"")
+        self.assertEqual(str(variants[53]), "\"INTERVAL '75960402' HOUR\"")
+        self.assertEqual(str(variants[54]), "\"INTERVAL '65064' DAY\"")
+        self.assertEqual(str(variants[55]), "\"INTERVAL '-4936892 01:20:52' DAY TO SECOND\"")
+        self.assertEqual(str(variants[56]), "\"INTERVAL '-646620:38:54' HOUR TO SECOND\"")
+        self.assertEqual(str(variants[57]), "\"INTERVAL '-5403720:32' MINUTE TO SECOND\"")
+        self.assertEqual(str(variants[58]), "\"INTERVAL '-2342332' SECOND\"")
+        self.assertEqual(str(variants[59]), "\"INTERVAL '-4936892 01:20' DAY TO MINUTE\"")
+        self.assertEqual(str(variants[60]), "\"INTERVAL '-646620:38' HOUR TO MINUTE\"")
+        self.assertEqual(str(variants[61]), "\"INTERVAL '-5403720' MINUTE\"")
+        self.assertEqual(str(variants[62]), "\"INTERVAL '-4936892 01' DAY TO HOUR\"")
+        self.assertEqual(str(variants[63]), "\"INTERVAL '-646620' HOUR\"")
+        self.assertEqual(str(variants[64]), "\"INTERVAL '-90062' DAY\"")
+        self.assertEqual(str(variants[65]), "\"INTERVAL '178956970-7' YEAR TO MONTH\"")
+        self.assertEqual(str(variants[66]), "\"INTERVAL '2147483647' MONTH\"")
+        self.assertEqual(str(variants[67]), "\"INTERVAL '178956970' YEAR\"")
+        self.assertEqual(str(variants[68]), "\"INTERVAL '-178956970-8' YEAR TO MONTH\"")
+        self.assertEqual(str(variants[69]), "\"INTERVAL '-2147483648' MONTH\"")
+        self.assertEqual(str(variants[70]), "\"INTERVAL '-178956970' YEAR\"")
+        self.assertEqual(str(variants[71]), "\"INTERVAL '0-0' YEAR TO MONTH\"")
+        self.assertEqual(str(variants[72]), "\"INTERVAL '0' MONTH\"")
+        self.assertEqual(str(variants[73]), "\"INTERVAL '0' YEAR\"")
+        self.assertEqual(str(variants[74]), "\"INTERVAL '2056544-9' YEAR TO MONTH\"")
+        self.assertEqual(str(variants[75]), "\"INTERVAL '345763467' MONTH\"")
+        self.assertEqual(str(variants[76]), "\"INTERVAL '45723888' YEAR\"")
+        self.assertEqual(str(variants[77]), "\"INTERVAL '-35437112-1' YEAR TO MONTH\"")
+        self.assertEqual(str(variants[78]), "\"INTERVAL '-849348229' MONTH\"")
+        self.assertEqual(str(variants[79]), "\"INTERVAL '-85349890' YEAR\"")
         # Check to_json on timestamps with custom timezones
         self.assertEqual(
diff --git a/python/pyspark/sql/variant_utils.py b/python/pyspark/sql/variant_utils.py
index cf185db3471ee..99b952da3997a 100644
--- a/python/pyspark/sql/variant_utils.py
+++ b/python/pyspark/sql/variant_utils.py
@@ -453,11 +453,12 @@ def _to_year_month_interval_ansi_string(cls, months: int, start_field: int, end_
             format_builder.append(str(months) + "' MONTH")
         else:
             format_builder.append(year_and_month + "' YEAR TO MONTH")
-        return ''.join(format_builder)
+        return "".join(format_builder)
 
     @classmethod
-    def _to_day_time_interval_ansi_string(cls,
-                                          micros: int, start_field: int, end_field: int) -> str:
+    def _to_day_time_interval_ansi_string(
+        cls, micros: int, start_field: int, end_field: int
+    ) -> str:
         """
         Used to convert microseconds representing a day-time interval with given start and
         end fields to its ANSI SQL string representation.
         """
@@ -558,9 +559,7 @@ def field_to_string(field: int) -> str:
         if start_field < SECOND and SECOND <= end_field:
             lead_zero = "0" if (rest < 10 * MICROS_PER_SECOND) else ""
             format_builder.append(
-                ":"
-                + lead_zero
-                + (Decimal(rest) / Decimal(1000000)).normalize().to_eng_string()
+                ":" + lead_zero + (Decimal(rest) / Decimal(1000000)).normalize().to_eng_string()
             )
         return prefix + ("".join(format_builder) % tuple(format_args)) + postfix
 
@@ -592,8 +591,9 @@ def handle_array(value_pos_list: List[int]) -> str:
             return '"' + cls._to_day_time_interval_ansi_string(micros, start_field, end_field) + '"'
         elif variant_type == cls._PlaceholderYearMonthIntervalInternalType:
             months, start_field, end_field = cls._get_yminterval_info(value, pos)
-            return '"' + cls._to_year_month_interval_ansi_string(months, start_field,
-                                                                 end_field) + '"'
+            return (
+                '"' + cls._to_year_month_interval_ansi_string(months, start_field, end_field) + '"'
+            )
         else:
             value = cls._get_scalar(variant_type, value, metadata, pos, zone_id)
             if value is None:
@@ -633,9 +633,7 @@ def handle_array(value_pos_list: List[int]) -> List[Any]:
             return cls._handle_array(value, pos, handle_array)
         elif variant_type == datetime.timedelta:
             # day-time intervals are represented using timedelta in a trivial manner
-            return datetime.timedelta(
-                microseconds=cls._get_dtinterval_info(value, pos)[0]
-            )
+            return datetime.timedelta(microseconds=cls._get_dtinterval_info(value, pos)[0])
         elif variant_type == cls._PlaceholderYearMonthIntervalInternalType:
             raise PySparkNotImplementedError(
                 errorClass="NOT_IMPLEMENTED",
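
For reference, the year-month formatting rule that `_to_year_month_interval_ansi_string` implements, and that the tests above pin down, can be sketched standalone as follows. This is a minimal sketch with an illustrative name, not the pyspark implementation itself:

def year_month_interval_ansi(months: int, start_field: int, end_field: int) -> str:
    # Field encoding per the variant spec comments: 0 = YEAR, 1 = MONTH.
    YEAR, MONTH = 0, 1
    sign = "-" if months < 0 else ""
    m = abs(months)
    if (start_field, end_field) == (YEAR, YEAR):
        body, unit = str(m // 12), "YEAR"
    elif (start_field, end_field) == (MONTH, MONTH):
        body, unit = str(m), "MONTH"
    else:
        # Years and months separated by a dash, e.g. '178956970-7'.
        body, unit = "%d-%d" % (m // 12, m % 12), "YEAR TO MONTH"
    return "INTERVAL '%s%s' %s" % (sign, body, unit)

For example, year_month_interval_ansi(-2147483648, 0, 1) returns "INTERVAL '-178956970-8' YEAR TO MONTH", matching the expectation on variants[68] above; Python's arbitrary-precision integers make the Int.MinValue negation safe here, which the Java side has to special-case.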
From 73cfde81a2e32771c7dcfa5568bcf4ecfe34767a Mon Sep 17 00:00:00 2001
From: Harsh Motwani
Date: Thu, 25 Jul 2024 11:08:08 -0700
Subject: [PATCH 13/17] resolved Gene's comments

---
 .../spark/types/variant/VariantUtil.java | 20 +++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
index ac856fb57890d..cb2336d9cff6e 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
@@ -121,10 +121,12 @@ public class VariantUtil {
   // string size) + (size bytes of string content).
   public static final int LONG_STR = 16;
   // year-month interval value. The content is one byte representing the start and end field values
-  // (1 bit each starting at least significant bits) and a 4-byte little-endian signed integer
+  // (1 bit each starting at least significant bits) and a 4-byte little-endian signed integer.
+  // A field value of 0 indicates year and a field value of 1 indicates month.
   public static final int YEAR_MONTH_INTERVAL = 19;
   // day-time interval value. The content is one byte representing the start and end field values
-  // (2 bits each starting at least significant bits) and an 8-byte little-endian signed integer
+  // (2 bits each starting at least significant bits) and an 8-byte little-endian signed integer.
+  // A field value of 0 indicates day, 1 indicates hour, 2 indicates minute, and 3 indicates second.
   public static final int DAY_TIME_INTERVAL = 20;
 
   public static final byte VERSION = 1;
@@ -384,8 +386,8 @@ public static boolean getBoolean(byte[] value, int pos) {
   // is guaranteed to fit into an int and represents the number of days from the Unix epoch.
   // If the type is `TIMESTAMP/TIMESTAMP_NTZ`, the return value represents the number of
   // microseconds from the Unix epoch. If the type is `YEAR_MONTH_INTERVAL`, the return value
-  // represents the number of months in the interval. If the type is `DAY_TIME_INTERVAL`, the
-  // return value represents the number of microseconds in the interval.
+  // is guaranteed to fit in an int and represents the number of months in the interval. If the type
+  // is `DAY_TIME_INTERVAL`, the return value represents the number of microseconds in the interval.
   // Throw `MALFORMED_VARIANT` if the variant is malformed.
   public static long getLong(byte[] value, int pos) {
     checkIndex(pos, value.length);
@@ -429,6 +431,11 @@ public IntervalFields(byte startField, byte endField) {
   // returned array contains the start field at the zeroth index and the end field at the first
   // index.
   public static IntervalFields getYearMonthIntervalFields(byte[] value, int pos) {
+    int basicType = value[pos] & BASIC_TYPE_MASK;
+    int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+    if (basicType != PRIMITIVE || typeInfo != YEAR_MONTH_INTERVAL) {
+      throw unexpectedType(Type.YEAR_MONTH_INTERVAL);
+    }
     long fieldInfo = readLong(value, pos + 1, 1);
     IntervalFields intervalFields = new IntervalFields((byte) (fieldInfo & 0x1),
         (byte) ((fieldInfo >> 1) & 0x1));
@@ -442,6 +449,11 @@ public static IntervalFields getYearMonthIntervalFields(byte[] value, int pos) {
   // returned array contains the start field at the zeroth index and the end field at the first
   // index.
   public static IntervalFields getDayTimeIntervalFields(byte[] value, int pos) {
+    int basicType = value[pos] & BASIC_TYPE_MASK;
+    int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+    if (basicType != PRIMITIVE || typeInfo != DAY_TIME_INTERVAL) {
+      throw unexpectedType(Type.DAY_TIME_INTERVAL);
+    }
     long fieldInfo = readLong(value, pos + 1, 1);
     IntervalFields intervalFields = new IntervalFields((byte) (fieldInfo & 0x3),
         (byte) ((fieldInfo >> 2) & 0x3));

From d3e419346816a77680f141eb04deab46d2f3c9ae Mon Sep 17 00:00:00 2001
From: Harsh Motwani
Date: Thu, 25 Jul 2024 14:02:46 -0700
Subject: [PATCH 14/17] lint fix

---
 python/pyspark/sql/variant_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/pyspark/sql/variant_utils.py b/python/pyspark/sql/variant_utils.py
index 99b952da3997a..4924ab9cb4157 100644
--- a/python/pyspark/sql/variant_utils.py
+++ b/python/pyspark/sql/variant_utils.py
@@ -436,7 +436,6 @@ def _to_year_month_interval_ansi_string(cls, months: int, start_field: int, end_
         fields to its ANSI SQL string representation.
         """
         YEAR = 0
-        MONTH = 1
        MONTHS_PER_YEAR = 12
         sign = ""
         abs_months = months
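
The guards added in the patch above key off the same header byte that every variant primitive carries. As a reference for the layout the spec comments describe, here is a minimal standalone decoder sketch; the constant values mirror those comments, and the function name and error strings are illustrative rather than VariantUtil's API:

import struct

PRIMITIVE = 0          # basic type id of a primitive, per the variant spec
BASIC_TYPE_BITS = 2
BASIC_TYPE_MASK = 0x3
TYPE_INFO_MASK = 0x3F
YEAR_MONTH_INTERVAL = 19
DAY_TIME_INTERVAL = 20

def decode_interval_primitive(value: bytes, pos: int):
    # Header byte: low 2 bits = basic type, next 6 bits = type info.
    basic_type = value[pos] & BASIC_TYPE_MASK
    type_info = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK
    if basic_type != PRIMITIVE:
        raise ValueError("MALFORMED_VARIANT")
    if type_info == YEAR_MONTH_INTERVAL:
        # One field-info byte (1 bit per field: 0 = YEAR, 1 = MONTH),
        # then a 4-byte little-endian signed month count.
        start, end = value[pos + 1] & 0x1, (value[pos + 1] >> 1) & 0x1
        (months,) = struct.unpack_from("<i", value, pos + 2)
        return ("year-month", months, start, end)
    if type_info == DAY_TIME_INTERVAL:
        # One field-info byte (2 bits per field: 0 = DAY .. 3 = SECOND),
        # then an 8-byte little-endian signed microsecond count.
        start, end = value[pos + 1] & 0x3, (value[pos + 1] >> 2) & 0x3
        (micros,) = struct.unpack_from("<q", value, pos + 2)
        return ("day-time", micros, start, end)
    raise ValueError("UNKNOWN_PRIMITIVE_TYPE_IN_VARIANT")

A malformed value where end < start is exactly what the _get_*_interval_fields checks in the patches below reject.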
From 417e419194a15ead29eed376e8fe519a3ff0afc8 Mon Sep 17 00:00:00 2001
From: Harsh Motwani
Date: Thu, 25 Jul 2024 17:25:30 -0700
Subject: [PATCH 15/17] fix

---
 python/pyspark/sql/variant_utils.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/sql/variant_utils.py b/python/pyspark/sql/variant_utils.py
index 4924ab9cb4157..eb1d24f0072c1 100644
--- a/python/pyspark/sql/variant_utils.py
+++ b/python/pyspark/sql/variant_utils.py
@@ -185,7 +185,7 @@ def _get_day_time_interval_fields(cls, value: bytes, pos: int) -> Tuple[int, int
         start_field = value[pos] & 0x3
         end_field = (value[pos] >> 2) & 0x3
         if end_field < start_field:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
+            raise PySparkValueError(errorClass="MALFORMED_VARIANT", message_parameters={})
         return (start_field, end_field)
 
     @classmethod
@@ -198,7 +198,7 @@ def _get_year_month_interval_fields(cls, value: bytes, pos: int) -> Tuple[int, i
         start_field = value[pos] & 0x1
         end_field = (value[pos] >> 1) & 0x1
         if end_field < start_field:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
+            raise PySparkValueError(errorClass="MALFORMED_VARIANT", message_parameters={})
         return (start_field, end_field)
 
     @classmethod
@@ -285,12 +285,12 @@ def _get_yminterval_info(cls, value: bytes, pos: int) -> Tuple[int, int, int]:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.PRIMITIVE:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
+            raise PySparkValueError(errorClass="MALFORMED_VARIANT", message_parameters={})
         if type_info == VariantUtils.YEAR_MONTH_INTERVAL:
             months = cls._read_long(value, pos + 2, 4, signed=True)
             start_field, end_field = cls._get_year_month_interval_fields(value, pos + 1)
             return (months, start_field, end_field)
-        raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
+        raise PySparkValueError(errorClass="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
@@ -301,12 +301,12 @@ def _get_dtinterval_info(cls, value: bytes, pos: int) -> Tuple[int, int, int]:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.PRIMITIVE:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
+            raise PySparkValueError(errorClass="MALFORMED_VARIANT", message_parameters={})
         if type_info == VariantUtils.DAY_TIME_INTERVAL:
             micros = cls._read_long(value, pos + 2, 8, signed=True)
             start_field, end_field = cls._get_day_time_interval_fields(value, pos + 1)
             return (micros, start_field, end_field)
-        raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
+        raise PySparkValueError(errorClass="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
     def _get_string(cls, value: bytes, pos: int) -> str:
@@ -486,10 +486,10 @@ def field_to_string(field: int) -> str:
                 return "MINUTE"
             elif field == SECOND:
                 return "SECOND"
-            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
+            raise PySparkValueError(errorClass="MALFORMED_VARIANT", message_parameters={})
 
         if end_field < start_field:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
+            raise PySparkValueError(errorClass="MALFORMED_VARIANT", message_parameters={})
         sign = ""
         rest = micros
         from_str = field_to_string(start_field).upper()

From 57660d69b4d7dbe87b0c0f2cd94ab20cd20668b8 Mon Sep 17 00:00:00 2001
From: Harsh Motwani
Date: Thu, 25 Jul 2024 20:00:47 -0700
Subject: [PATCH 16/17] fix

---
 python/pyspark/sql/variant_utils.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/sql/variant_utils.py b/python/pyspark/sql/variant_utils.py
index eb1d24f0072c1..2db22edc66645 100644
--- a/python/pyspark/sql/variant_utils.py
+++ b/python/pyspark/sql/variant_utils.py
@@ -185,7 +185,7 @@ def _get_day_time_interval_fields(cls, value: bytes, pos: int) -> Tuple[int, int
         start_field = value[pos] & 0x3
         end_field = (value[pos] >> 2) & 0x3
         if end_field < start_field:
-            raise PySparkValueError(errorClass="MALFORMED_VARIANT", message_parameters={})
+            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
         return (start_field, end_field)
 
     @classmethod
@@ -198,7 +198,7 @@ def _get_year_month_interval_fields(cls, value: bytes, pos: int) -> Tuple[int, i
         start_field = value[pos] & 0x1
         end_field = (value[pos] >> 1) & 0x1
         if end_field < start_field:
-            raise PySparkValueError(errorClass="MALFORMED_VARIANT", message_parameters={})
+            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
         return (start_field, end_field)
 
     @classmethod
@@ -285,12 +285,12 @@ def _get_yminterval_info(cls, value: bytes, pos: int) -> Tuple[int, int, int]:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.PRIMITIVE:
-            raise PySparkValueError(errorClass="MALFORMED_VARIANT", message_parameters={})
+            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
         if type_info == VariantUtils.YEAR_MONTH_INTERVAL:
             months = cls._read_long(value, pos + 2, 4, signed=True)
             start_field, end_field = cls._get_year_month_interval_fields(value, pos + 1)
             return (months, start_field, end_field)
-        raise PySparkValueError(errorClass="MALFORMED_VARIANT", message_parameters={})
+        raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
 
     @classmethod
@@ -301,12 +301,12 @@ def _get_dtinterval_info(cls, value: bytes, pos: int) -> Tuple[int, int, int]:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.PRIMITIVE:
-            raise PySparkValueError(errorClass="MALFORMED_VARIANT", message_parameters={})
+            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
         if type_info == VariantUtils.DAY_TIME_INTERVAL:
             micros = cls._read_long(value, pos + 2, 8, signed=True)
             start_field, end_field = cls._get_day_time_interval_fields(value, pos + 1)
             return (micros, start_field, end_field)
-        raise PySparkValueError(errorClass="MALFORMED_VARIANT", message_parameters={})
+        raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
 
     @classmethod
     def _get_string(cls, value: bytes, pos: int) -> str:
@@ -486,10 +486,10 @@ def field_to_string(field: int) -> str:
                 return "MINUTE"
             elif field == SECOND:
                 return "SECOND"
-            raise PySparkValueError(errorClass="MALFORMED_VARIANT", message_parameters={})
+            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
 
         if end_field < start_field:
-            raise PySparkValueError(errorClass="MALFORMED_VARIANT", message_parameters={})
+            raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
         sign = ""
         rest = micros
         from_str = field_to_string(start_field).upper()

From 52fed53958022f454c553f49bccdc78e9397b28a Mon Sep 17 00:00:00 2001
From: Harsh Motwani
Date: Thu, 25 Jul 2024 22:08:49 -0700
Subject: [PATCH 17/17] fix

---
 python/pyspark/sql/variant_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/variant_utils.py b/python/pyspark/sql/variant_utils.py
index 2db22edc66645..a6939b10c3018 100644
--- a/python/pyspark/sql/variant_utils.py
+++ b/python/pyspark/sql/variant_utils.py
@@ -430,7 +430,9 @@ def _get_type(cls, value: bytes, pos: int) -> Any:
         raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
 
     @classmethod
-    def _to_year_month_interval_ansi_string(cls, months: int, start_field: int, end_field: int):
+    def _to_year_month_interval_ansi_string(
+        cls, months: int, start_field: int, end_field: int
+    ) -> str:
         """
         Used to convert months representing a year-month interval with given start and
         end fields to its ANSI SQL string representation.
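
For reference, the day-time formatting rules exercised by the tests at the top of this section can also be sketched standalone. The names below are illustrative and this is not the pyspark code; in particular, the sketch builds the fractional seconds with plain string padding, whereas the patched implementation reaches the same strings through Decimal normalization:

MICROS_PER_SECOND = 1_000_000
MICROS_PER_MINUTE = 60 * MICROS_PER_SECOND
MICROS_PER_HOUR = 60 * MICROS_PER_MINUTE
MICROS_PER_DAY = 24 * MICROS_PER_HOUR
DAY, HOUR, MINUTE, SECOND = 0, 1, 2, 3
FIELD_NAMES = ["DAY", "HOUR", "MINUTE", "SECOND"]

def _seconds_part(rest: int) -> str:
    # Seconds are zero-padded to two digits; an all-zero microsecond
    # fraction is dropped, and trailing zeros are stripped otherwise.
    secs, frac = divmod(rest, MICROS_PER_SECOND)
    frac_str = "." + ("%06d" % frac).rstrip("0") if frac else ""
    return "%02d%s" % (secs, frac_str)

def day_time_interval_ansi(micros: int, start_field: int, end_field: int) -> str:
    # Python ints are unbounded, so negating the Long.MIN_VALUE case is safe.
    sign, rest = ("-", -micros) if micros < 0 else ("", micros)
    parts = []
    # The leading field carries the full remaining magnitude: days are
    # printed unpadded, hours and minutes are padded to two digits.
    if start_field == DAY:
        parts.append(str(rest // MICROS_PER_DAY))
        rest %= MICROS_PER_DAY
    elif start_field == HOUR:
        parts.append("%02d" % (rest // MICROS_PER_HOUR))
        rest %= MICROS_PER_HOUR
    elif start_field == MINUTE:
        parts.append("%02d" % (rest // MICROS_PER_MINUTE))
        rest %= MICROS_PER_MINUTE
    else:
        parts.append(_seconds_part(rest))
    # Trailing fields use a space after DAY and colons elsewhere.
    if start_field < HOUR <= end_field:
        parts.append(" %02d" % (rest // MICROS_PER_HOUR))
        rest %= MICROS_PER_HOUR
    if start_field < MINUTE <= end_field:
        parts.append(":%02d" % (rest // MICROS_PER_MINUTE))
        rest %= MICROS_PER_MINUTE
    if start_field < SECOND <= end_field:
        parts.append(":" + _seconds_part(rest))
    unit = FIELD_NAMES[start_field] if start_field == end_field else (
        FIELD_NAMES[start_field] + " TO " + FIELD_NAMES[end_field])
    return "INTERVAL '%s%s' %s" % (sign, "".join(parts), unit)

For example, day_time_interval_ansi(0, DAY, SECOND) returns "INTERVAL '0 00:00:00' DAY TO SECOND", matching the expectation on variants[35], and a negative microsecond count yields forms like "INTERVAL '-4936892 01:20:52' DAY TO SECOND" as on variants[55].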