Skip to content

Commit

Permalink
SNOW-1234216 Native Arrow structured types array support (#1687)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-dheyman authored Apr 2, 2024
1 parent 6a209b0 commit 43d8fba
Show file tree
Hide file tree
Showing 16 changed files with 659 additions and 440 deletions.
6 changes: 1 addition & 5 deletions src/main/java/net/snowflake/client/core/ArrowSqlInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,7 @@ public Timestamp readTimestamp(TimeZone tz) throws SQLException {
converters
.getStructuredTypeDateTimeConverter()
.getTimestamp(
(JsonStringHashMap<String, Object>) value,
columnType,
columnSubType,
tz,
scale));
(Map<String, Object>) value, columnType, columnSubType, tz, scale));
});
}

Expand Down
172 changes: 162 additions & 10 deletions src/main/java/net/snowflake/client/core/SFArrowResultSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@
import java.sql.Array;
import java.sql.Date;
import java.sql.SQLException;
import java.sql.SQLInput;
import java.sql.Time;
import java.sql.Timestamp;
import java.sql.Types;
import java.util.List;
import java.util.Map;
import java.util.TimeZone;
import java.util.stream.Stream;
import net.snowflake.client.core.arrow.ArrayConverter;
import net.snowflake.client.core.arrow.ArrowVectorConverter;
import net.snowflake.client.core.arrow.StructConverter;
import net.snowflake.client.core.arrow.VarCharConverter;
Expand All @@ -29,6 +33,8 @@
import net.snowflake.client.jdbc.ArrowResultChunk;
import net.snowflake.client.jdbc.ArrowResultChunk.ArrowChunkIterator;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.jdbc.FieldMetadata;
import net.snowflake.client.jdbc.SnowflakeColumnMetadata;
import net.snowflake.client.jdbc.SnowflakeResultSetSerializableV1;
import net.snowflake.client.jdbc.SnowflakeSQLException;
import net.snowflake.client.jdbc.SnowflakeSQLLoggedException;
Expand All @@ -38,6 +44,7 @@
import net.snowflake.client.jdbc.telemetry.TelemetryUtil;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;
import net.snowflake.client.util.Converter;
import net.snowflake.common.core.SFBinaryFormat;
import net.snowflake.common.core.SnowflakeDateTimeFormat;
import net.snowflake.common.core.SqlState;
Expand Down Expand Up @@ -199,7 +206,7 @@ public SFArrowResultSet(
this.timestampTZFormatter = resultSetSerializable.getTimestampTZFormatter();
this.dateFormatter = resultSetSerializable.getDateFormatter();
this.timeFormatter = resultSetSerializable.getTimeFormatter();
this.sessionTimezone = resultSetSerializable.getTimeZone();
this.sessionTimeZone = resultSetSerializable.getTimeZone();
this.binaryFormatter = resultSetSerializable.getBinaryFormatter();
this.resultSetMetaData = resultSetSerializable.getSFResultSetMetaData();
this.treatNTZAsUTC = resultSetSerializable.getTreatNTZAsUTC();
Expand Down Expand Up @@ -364,18 +371,45 @@ public Converters getConverters() {
}

@Override
@SnowflakeJdbcInternalApi
public SQLInput createSqlInputForColumn(
Object input, Class<?> parentObjectClass, int columnIndex, SFBaseSession session) {
if (parentObjectClass.equals(JsonSqlInput.class)) {
return createJsonSqlInputForColumn(input, columnIndex, session);
} else {
return new ArrowSqlInput(
(Map<String, Object>) input,
session,
converters,
resultSetMetaData.getColumnMetadata().get(columnIndex - 1).getFields());
}
}

@Override
@SnowflakeJdbcInternalApi
public Date convertToDate(Object object, TimeZone tz) throws SFException {
if (object instanceof String) {
return convertStringToDate((String) object, tz);
}
return converters.getStructuredTypeDateTimeConverter().getDate((int) object, tz);
}

@Override
@SnowflakeJdbcInternalApi
public Time convertToTime(Object object, int scale) throws SFException {
if (object instanceof String) {
return convertStringToTime((String) object, scale);
}
return converters.getStructuredTypeDateTimeConverter().getTime((long) object, scale);
}

@Override
@SnowflakeJdbcInternalApi
public Timestamp convertToTimestamp(
Object object, int columnType, int columnSubType, TimeZone tz, int scale) throws SFException {
if (object instanceof String) {
return convertStringToTimestamp((String) object, columnType, columnSubType, tz, scale);
}
return converters
.getStructuredTypeDateTimeConverter()
.getTimestamp(
Expand Down Expand Up @@ -497,7 +531,7 @@ public Date getDate(int columnIndex, TimeZone tz) throws SFException {
ArrowVectorConverter converter = currentChunkIterator.getCurrentConverter(columnIndex - 1);
int index = currentChunkIterator.getCurrentRowInRecordBatch();
wasNull = converter.isNull(index);
converter.setSessionTimeZone(sessionTimezone);
converter.setSessionTimeZone(sessionTimeZone);
converter.setUseSessionTimezone(useSessionTimezone);
return converter.toDate(index, tz, resultSetSerializable.getFormatDateWithTimeZone());
}
Expand All @@ -507,7 +541,7 @@ public Time getTime(int columnIndex) throws SFException {
ArrowVectorConverter converter = currentChunkIterator.getCurrentConverter(columnIndex - 1);
int index = currentChunkIterator.getCurrentRowInRecordBatch();
wasNull = converter.isNull(index);
converter.setSessionTimeZone(sessionTimezone);
converter.setSessionTimeZone(sessionTimeZone);
converter.setUseSessionTimezone(useSessionTimezone);
return converter.toTime(index);
}
Expand All @@ -516,7 +550,7 @@ public Time getTime(int columnIndex) throws SFException {
public Timestamp getTimestamp(int columnIndex, TimeZone tz) throws SFException {
ArrowVectorConverter converter = currentChunkIterator.getCurrentConverter(columnIndex - 1);
int index = currentChunkIterator.getCurrentRowInRecordBatch();
converter.setSessionTimeZone(sessionTimezone);
converter.setSessionTimeZone(sessionTimeZone);
converter.setUseSessionTimezone(useSessionTimezone);
wasNull = converter.isNull(index);
return converter.toTimestamp(index, tz);
Expand All @@ -529,7 +563,7 @@ public Object getObject(int columnIndex) throws SFException {
wasNull = converter.isNull(index);
converter.setTreatNTZAsUTC(treatNTZAsUTC);
converter.setUseSessionTimezone(useSessionTimezone);
converter.setSessionTimeZone(sessionTimezone);
converter.setSessionTimeZone(sessionTimeZone);
Object obj = converter.toObject(index);
int type = resultSetMetaData.getColumnType(columnIndex);
if (type == Types.STRUCT && StructureTypeHelper.isStructureTypeEnabled()) {
Expand All @@ -550,7 +584,7 @@ private Object createJsonSqlInput(int columnIndex, Object obj) throws SFExceptio
session,
converters,
resultSetMetaData.getColumnMetadata().get(columnIndex - 1).getFields(),
sessionTimezone);
sessionTimeZone);
} catch (JsonProcessingException e) {
throw new SFException(e, ErrorCode.INVALID_STRUCT_DATA);
}
Expand All @@ -566,16 +600,134 @@ private Object createArrowSqlInput(int columnIndex, Map<String, Object> input) {

@Override
public Array getArray(int columnIndex) throws SFException {
// TODO: handleArray SNOW-969794
throw new SFException(ErrorCode.FEATURE_UNSUPPORTED, "data type ARRAY");
ArrowVectorConverter converter = currentChunkIterator.getCurrentConverter(columnIndex - 1);
int index = currentChunkIterator.getCurrentRowInRecordBatch();
wasNull = converter.isNull(index);
Object obj = converter.toObject(index);
if (converter instanceof VarCharConverter) {
return getJsonArray((String) obj, columnIndex);
} else if (converter instanceof ArrayConverter) {
return getArrowArray((List<Object>) obj, columnIndex);
} else {
throw new SFException(ErrorCode.INVALID_STRUCT_DATA);
}
}

private SfSqlArray getArrowArray(List<Object> elements, int columnIndex) throws SFException {
try {
SnowflakeColumnMetadata arrayMetadata =
resultSetMetaData.getColumnMetadata().get(columnIndex - 1);
FieldMetadata fieldMetadata = arrayMetadata.getFields().get(0);

int columnSubType = fieldMetadata.getType();
int columnType = ColumnTypeHelper.getColumnType(columnSubType, session);
int scale = fieldMetadata.getScale();

switch (columnSubType) {
case Types.INTEGER:
return new SfSqlArray(
columnSubType,
mapAndConvert(elements, converters.integerConverter(columnType))
.toArray(Integer[]::new));
case Types.SMALLINT:
return new SfSqlArray(
columnSubType,
mapAndConvert(elements, converters.smallIntConverter(columnType))
.toArray(Short[]::new));
case Types.TINYINT:
return new SfSqlArray(
columnSubType,
mapAndConvert(elements, converters.tinyIntConverter(columnType))
.toArray(Byte[]::new));
case Types.BIGINT:
return new SfSqlArray(
columnSubType,
mapAndConvert(elements, converters.bigIntConverter(columnType)).toArray(Long[]::new));
case Types.DECIMAL:
case Types.NUMERIC:
return new SfSqlArray(
columnSubType,
mapAndConvert(elements, converters.bigDecimalConverter(columnType))
.toArray(BigDecimal[]::new));
case Types.CHAR:
case Types.VARCHAR:
case Types.LONGNVARCHAR:
return new SfSqlArray(
columnSubType,
mapAndConvert(elements, converters.varcharConverter(columnType, columnSubType, scale))
.toArray(String[]::new));
case Types.BINARY:
return new SfSqlArray(
columnSubType,
mapAndConvert(elements, converters.bytesConverter(columnType, scale))
.toArray(Byte[][]::new));
case Types.FLOAT:
case Types.REAL:
return new SfSqlArray(
columnSubType,
mapAndConvert(elements, converters.floatConverter(columnType)).toArray(Float[]::new));
case Types.DOUBLE:
return new SfSqlArray(
columnSubType,
mapAndConvert(elements, converters.doubleConverter(columnType))
.toArray(Double[]::new));
case Types.DATE:
return new SfSqlArray(
columnSubType,
mapAndConvert(elements, converters.dateFromIntConverter(sessionTimeZone))
.toArray(Date[]::new));
case Types.TIME:
return new SfSqlArray(
columnSubType,
mapAndConvert(elements, converters.timeFromIntConverter(scale)).toArray(Time[]::new));
case Types.TIMESTAMP:
return new SfSqlArray(
columnSubType,
mapAndConvert(
elements,
converters.timestampFromStructConverter(
columnType, columnSubType, sessionTimeZone, scale))
.toArray(Timestamp[]::new));
case Types.BOOLEAN:
return new SfSqlArray(
columnSubType,
mapAndConvert(elements, converters.booleanConverter(columnType))
.toArray(Boolean[]::new));
case Types.STRUCT:
return new SfSqlArray(columnSubType, mapAndConvert(elements, e -> e).toArray(Map[]::new));
case Types.ARRAY:
return new SfSqlArray(
columnSubType,
mapAndConvert(elements, e -> ((List) e).stream().toArray(Map[]::new))
.toArray(Map[][]::new));
default:
throw new SFException(
ErrorCode.FEATURE_UNSUPPORTED,
"Can't construct array for data type: " + columnSubType);
}
} catch (RuntimeException e) {
throw new SFException(e, ErrorCode.INVALID_STRUCT_DATA);
}
}

private <T> Stream<T> mapAndConvert(List<Object> elements, Converter<T> converter) {
return elements.stream()
.map(
obj -> {
try {
return converter.convert(obj);
} catch (SFException e) {
throw new RuntimeException(e);
}
});
}

@Override
public BigDecimal getBigDecimal(int columnIndex) throws SFException {
ArrowVectorConverter converter = currentChunkIterator.getCurrentConverter(columnIndex - 1);
int index = currentChunkIterator.getCurrentRowInRecordBatch();
wasNull = converter.isNull(index);
converter.setSessionTimeZone(sessionTimezone);
converter.setSessionTimeZone(sessionTimeZone);
converter.setUseSessionTimezone(useSessionTimezone);
return converter.toBigDecimal(index);
}
Expand Down Expand Up @@ -715,7 +867,7 @@ public int getScale(int columnIndex) {

@Override
public TimeZone getTimeZone() {
return sessionTimezone;
return sessionTimeZone;
}

@Override
Expand Down
Loading

0 comments on commit 43d8fba

Please sign in to comment.