-
Notifications
You must be signed in to change notification settings - Fork 3.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for double to varchar coercion in hive tables #18832
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
/* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package io.trino.plugin.hive.coercions; | ||
|
||
import io.airlift.slice.Slice; | ||
import io.airlift.slice.Slices; | ||
import io.trino.spi.TrinoException; | ||
import io.trino.spi.block.Block; | ||
import io.trino.spi.block.BlockBuilder; | ||
import io.trino.spi.type.DoubleType; | ||
import io.trino.spi.type.VarcharType; | ||
|
||
import static io.airlift.slice.SliceUtf8.countCodePoints; | ||
import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; | ||
import static io.trino.spi.type.DoubleType.DOUBLE; | ||
import static io.trino.spi.type.Varchars.truncateToLength; | ||
import static java.lang.String.format; | ||
|
||
public class DoubleToVarcharCoercer | ||
extends TypeCoercer<DoubleType, VarcharType> | ||
{ | ||
public DoubleToVarcharCoercer(VarcharType toType) | ||
{ | ||
super(DOUBLE, toType); | ||
} | ||
|
||
@Override | ||
protected void applyCoercedValue(BlockBuilder blockBuilder, Block block, int position) | ||
{ | ||
double doubleValue = DOUBLE.getDouble(block, position); | ||
Slice converted = Slices.utf8Slice(Double.toString(doubleValue)); | ||
if (!toType.isUnbounded() && countCodePoints(converted) > toType.getBoundedLength()) { | ||
throw new TrinoException(INVALID_ARGUMENTS, format("Varchar representation of %s exceeds %s bounds", doubleValue, toType)); | ||
} | ||
toType.writeSlice(blockBuilder, truncateToLength(converted, toType)); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
/* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package io.trino.plugin.hive.coercions; | ||
|
||
import io.airlift.slice.Slices; | ||
import io.trino.spi.TrinoException; | ||
import io.trino.spi.block.Block; | ||
import io.trino.spi.type.Type; | ||
import org.testng.annotations.DataProvider; | ||
import org.testng.annotations.Test; | ||
|
||
import static io.trino.plugin.hive.HiveTimestampPrecision.DEFAULT_PRECISION; | ||
import static io.trino.plugin.hive.HiveType.toHiveType; | ||
import static io.trino.plugin.hive.coercions.CoercionUtils.createCoercer; | ||
import static io.trino.spi.predicate.Utils.blockToNativeValue; | ||
import static io.trino.spi.predicate.Utils.nativeValueToBlock; | ||
import static io.trino.spi.type.DoubleType.DOUBLE; | ||
import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; | ||
import static io.trino.spi.type.VarcharType.createVarcharType; | ||
import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; | ||
import static org.assertj.core.api.Assertions.assertThat; | ||
import static org.assertj.core.api.Assertions.assertThatThrownBy; | ||
|
||
public class TestDoubleToVarcharCoercions | ||
{ | ||
@Test(dataProvider = "doubleValues") | ||
public void testNaNToVarcharCoercions(Double doubleValue) | ||
{ | ||
assertCoercions(DOUBLE, doubleValue, createUnboundedVarcharType(), Slices.utf8Slice(doubleValue.toString())); | ||
} | ||
|
||
@Test(dataProvider = "doubleValues") | ||
public void testDoubleSmallerVarcharCoercions(Double doubleValue) | ||
{ | ||
assertThatThrownBy(() -> assertCoercions(DOUBLE, doubleValue, createVarcharType(1), doubleValue.toString())) | ||
.isInstanceOf(TrinoException.class) | ||
.hasMessageContaining("Varchar representation of %s exceeds varchar(1) bounds", doubleValue); | ||
} | ||
|
||
@DataProvider | ||
public Object[][] doubleValues() | ||
{ | ||
return new Object[][] { | ||
{Double.MAX_VALUE}, | ||
{Double.MAX_VALUE}, | ||
{Double.parseDouble("123456789.12345678")}, | ||
{Double.NaN}, | ||
}; | ||
} | ||
|
||
public static void assertCoercions(Type fromType, Object valueToBeCoerced, Type toType, Object expectedValue) | ||
{ | ||
Block coercedValue = createCoercer(TESTING_TYPE_MANAGER, toHiveType(fromType), toHiveType(toType), DEFAULT_PRECISION).orElseThrow() | ||
.apply(nativeValueToBlock(fromType, valueToBeCoerced)); | ||
assertThat(blockToNativeValue(toType, coercedValue)) | ||
.isEqualTo(expectedValue); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -105,6 +105,8 @@ protected void doTestHiveCoercion(HiveTableDefinition tableDefinition) | |
"bigint_to_varchar", | ||
"float_to_double", | ||
"double_to_float", | ||
"double_to_string", | ||
"double_to_bounded_varchar", | ||
"shortdecimal_to_shortdecimal", | ||
"shortdecimal_to_longdecimal", | ||
"longdecimal_to_shortdecimal", | ||
|
@@ -166,6 +168,8 @@ protected void insertTableRows(String tableName, String floatToDoubleType) | |
" 12345, " + | ||
" REAL '0.5', " + | ||
" DOUBLE '0.5', " + | ||
" DOUBLE '12345.12345', " + | ||
" DOUBLE '12345.12345', " + | ||
" DECIMAL '12345678.12', " + | ||
" DECIMAL '12345678.12', " + | ||
" DECIMAL '12345678.123456123456', " + | ||
|
@@ -201,6 +205,8 @@ protected void insertTableRows(String tableName, String floatToDoubleType) | |
" -12345, " + | ||
" REAL '-1.5', " + | ||
" DOUBLE '-1.5', " + | ||
" DOUBLE 'NaN', " + | ||
" DOUBLE '-12345.12345', " + | ||
" DECIMAL '-12345678.12', " + | ||
" DECIMAL '-12345678.12', " + | ||
" DECIMAL '-12345678.123456123456', " + | ||
|
@@ -230,6 +236,7 @@ protected void insertTableRows(String tableName, String floatToDoubleType) | |
protected Map<String, List<Object>> expectedValuesForEngineProvider(Engine engine, String tableName, String decimalToFloatVal, String floatToDecimalVal) | ||
{ | ||
String hiveValueForCaseChangeField; | ||
String coercedNaN = "NaN"; | ||
Predicate<String> isFormat = formatName -> tableName.toLowerCase(ENGLISH).contains(formatName); | ||
if (isFormat.test("rctext") || isFormat.test("textfile")) { | ||
hiveValueForCaseChangeField = "\"lower2uppercase\":2"; | ||
|
@@ -241,6 +248,11 @@ else if (getHiveVersionMajor() == 3 && isFormat.test("orc")) { | |
hiveValueForCaseChangeField = "\"LOWER2UPPERCASE\":2"; | ||
} | ||
|
||
// For ORC when we coerce NaN to String, it returns coerced value as `null` | ||
if (isFormat.test("orc") && engine == Engine.HIVE) { | ||
coercedNaN = null; | ||
} | ||
|
||
return ImmutableMap.<String, List<Object>>builder() | ||
.put("row_to_row", ImmutableList.of( | ||
engine == Engine.TRINO ? | ||
|
@@ -321,6 +333,8 @@ else if (getHiveVersionMajor() == 3 && isFormat.test("orc")) { | |
0.5, | ||
-1.5)) | ||
.put("double_to_float", ImmutableList.of(0.5, -1.5)) | ||
.put("double_to_string", Arrays.asList("12345.12345", coercedNaN)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ImmutableList too? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, got it. |
||
.put("double_to_bounded_varchar", ImmutableList.of("12345.12345", "-12345.12345")) | ||
.put("shortdecimal_to_shortdecimal", ImmutableList.of( | ||
new BigDecimal("12345678.1200"), | ||
new BigDecimal("-12345678.1200"))) | ||
|
@@ -750,6 +764,8 @@ private void assertProperAlteredTableSchema(String tableName) | |
row("bigint_to_varchar", "varchar"), | ||
row("float_to_double", "double"), | ||
row("double_to_float", floatType), | ||
row("double_to_string", "varchar"), | ||
row("double_to_bounded_varchar", "varchar(12)"), | ||
row("shortdecimal_to_shortdecimal", "decimal(18,4)"), | ||
row("shortdecimal_to_longdecimal", "decimal(20,4)"), | ||
row("longdecimal_to_shortdecimal", "decimal(12,2)"), | ||
|
@@ -801,6 +817,8 @@ private void assertColumnTypes( | |
.put("bigint_to_varchar", VARCHAR) | ||
.put("float_to_double", DOUBLE) | ||
.put("double_to_float", floatType) | ||
.put("double_to_string", VARCHAR) | ||
.put("double_to_bounded_varchar", VARCHAR) | ||
.put("shortdecimal_to_shortdecimal", DECIMAL) | ||
.put("shortdecimal_to_longdecimal", DECIMAL) | ||
.put("longdecimal_to_shortdecimal", DECIMAL) | ||
|
@@ -851,6 +869,8 @@ private static void alterTableColumnTypes(String tableName) | |
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN bigint_to_varchar bigint_to_varchar string", tableName)); | ||
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN float_to_double float_to_double double", tableName)); | ||
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN double_to_float double_to_float %s", tableName, floatType)); | ||
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN double_to_string double_to_string string", tableName)); | ||
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN double_to_bounded_varchar double_to_bounded_varchar varchar(12)", tableName)); | ||
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN shortdecimal_to_shortdecimal shortdecimal_to_shortdecimal DECIMAL(18,4)", tableName)); | ||
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN shortdecimal_to_longdecimal shortdecimal_to_longdecimal DECIMAL(20,4)", tableName)); | ||
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN longdecimal_to_shortdecimal longdecimal_to_shortdecimal DECIMAL(12,2)", tableName)); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,6 +61,8 @@ private static HiveTableDefinition.HiveTableDefinitionBuilder tableDefinitionBui | |
bigint_to_varchar BIGINT, | ||
float_to_double FLOAT, | ||
double_to_float DOUBLE, | ||
double_to_string DOUBLE, | ||
double_to_bounded_varchar DOUBLE, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about a test for out of bounds? |
||
shortdecimal_to_shortdecimal DECIMAL(10,2), | ||
shortdecimal_to_longdecimal DECIMAL(10,2), | ||
longdecimal_to_shortdecimal DECIMAL(20,12), | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this a bug that we should track?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is how ORC file behaves, I'm not sure if we need to track them in ORC project