Skip to content

Commit 7d61625

Browse files
Avoid batch reading of nested decimals as this reader is not implemented
1 parent e223cc1 commit 7d61625

File tree

2 files changed

+102
-2
lines changed

2 files changed

+102
-2
lines changed

presto-hive/src/test/java/com/facebook/presto/hive/parquet/AbstractTestParquetReader.java

+100
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
import static com.facebook.presto.tests.StructuralTestUtil.mapType;
9797
import static com.google.common.base.Functions.compose;
9898
import static com.google.common.base.Preconditions.checkArgument;
99+
import static com.google.common.collect.ImmutableList.toImmutableList;
99100
import static com.google.common.collect.Iterables.concat;
100101
import static com.google.common.collect.Iterables.cycle;
101102
import static com.google.common.collect.Iterables.limit;
@@ -181,6 +182,105 @@ public void testNestedArrays()
181182
tester.testRoundTrip(objectInspector, values, values, type);
182183
}
183184

185+
@Test
186+
public void testNestedArraysDecimalBackedByINT32()
187+
throws Exception
188+
{
189+
int precision = 1;
190+
int scale = 0;
191+
ObjectInspector objectInspector = getStandardListObjectInspector(javaIntObjectInspector);
192+
Type type = new ArrayType(createDecimalType(precision, scale));
193+
Iterable<List<Integer>> values = createTestArrays(intsBetween(1, 1_000));
194+
195+
ImmutableList.Builder<List<SqlDecimal>> expectedValues = new ImmutableList.Builder<>();
196+
for (List<Integer> value : values) {
197+
expectedValues.add(value.stream()
198+
.map(valueInt -> SqlDecimal.of(valueInt, precision, scale))
199+
.collect(toImmutableList()));
200+
}
201+
202+
MessageType hiveSchema = parseMessageType(format("message hive_list_decimal {" +
203+
" optional group my_list (LIST){" +
204+
" repeated group list {" +
205+
" optional INT32 element (DECIMAL(%d, %d));" +
206+
" }" +
207+
" }" +
208+
"} ", precision, scale));
209+
210+
tester.testRoundTrip(objectInspector, values, expectedValues.build(), "my_list", type, Optional.of(hiveSchema));
211+
}
212+
213+
@Test
214+
public void testNestedArraysDecimalBackedByINT64()
215+
throws Exception
216+
{
217+
int precision = 10;
218+
int scale = 2;
219+
ObjectInspector objectInspector = getStandardListObjectInspector(javaLongObjectInspector);
220+
Type type = new ArrayType(createDecimalType(precision, scale));
221+
Iterable<List<Long>> values = createTestArrays(longsBetween(1, 1_000));
222+
223+
ImmutableList.Builder<List<SqlDecimal>> expectedValues = new ImmutableList.Builder<>();
224+
for (List<Long> value : values) {
225+
expectedValues.add(value.stream()
226+
.map(valueLong -> SqlDecimal.of(valueLong, precision, scale))
227+
.collect(toImmutableList()));
228+
}
229+
230+
MessageType hiveSchema = parseMessageType(format("message hive_list_decimal {" +
231+
" optional group my_list (LIST){" +
232+
" repeated group list {" +
233+
" optional INT64 element (DECIMAL(%d, %d));" +
234+
" }" +
235+
" }" +
236+
"} ", precision, scale));
237+
tester.testRoundTrip(objectInspector, values, expectedValues.build(), "my_list", type, Optional.of(hiveSchema));
238+
}
239+
240+
@Test
241+
public void testNestedArraysShortDecimalBackedByBinary()
242+
throws Exception
243+
{
244+
int precision = 1;
245+
int scale = 0;
246+
ObjectInspector objectInspector = getStandardListObjectInspector(new JavaHiveDecimalObjectInspector(new DecimalTypeInfo(precision, scale)));
247+
Type type = new ArrayType(createDecimalType(precision, scale));
248+
Iterable<List<HiveDecimal>> values = getNestedDecimalArrayInputValues(precision, scale);
249+
List<List<SqlDecimal>> expectedValues = getNestedDecimalArrayExpectedValues(values, precision, scale);
250+
251+
MessageType hiveSchema = parseMessageType(format("message hive_list_decimal {" +
252+
" optional group my_list (LIST){" +
253+
" repeated group list {" +
254+
" optional BINARY element (DECIMAL(%d, %d));" +
255+
" }" +
256+
" }" +
257+
"} ", precision, scale));
258+
259+
tester.testRoundTrip(objectInspector, values, expectedValues, "my_list", type, Optional.of(hiveSchema));
260+
}
261+
262+
private Iterable<List<HiveDecimal>> getNestedDecimalArrayInputValues(int precision, int scale)
263+
{
264+
ContiguousSet<BigInteger> bigIntegerValues = bigIntegersBetween(BigDecimal.valueOf(Math.pow(10, precision - 1)).toBigInteger(),
265+
BigDecimal.valueOf(Math.pow(10, precision)).toBigInteger());
266+
List<HiveDecimal> writeValues = bigIntegerValues.stream()
267+
.map(value -> HiveDecimal.create((BigInteger) value, scale))
268+
.collect(toImmutableList());
269+
270+
return createTestArrays(writeValues);
271+
}
272+
273+
private static List<List<SqlDecimal>> getNestedDecimalArrayExpectedValues(Iterable<List<HiveDecimal>> values, int precision, int scale)
274+
{
275+
ImmutableList.Builder<List<SqlDecimal>> expectedValues = new ImmutableList.Builder<>();
276+
for (List<HiveDecimal> value : values) {
277+
expectedValues.add(value.stream()
278+
.map(valueHiveDecimal -> new SqlDecimal(valueHiveDecimal.unscaledValue(), precision, scale))
279+
.collect(toImmutableList()));
280+
}
281+
return expectedValues.build();
282+
}
283+
184284
@Test
185285
public void testSingleLevelSchemaNestedArrays()
186286
throws Exception

presto-parquet/src/main/java/com/facebook/presto/parquet/ColumnReaderFactory.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ private ColumnReaderFactory()
6363

6464
public static ColumnReader createReader(RichColumnDescriptor descriptor, boolean batchReadEnabled)
6565
{
66-
if (batchReadEnabled) {
67-
final boolean isNested = descriptor.getPath().length > 1;
66+
final boolean isNested = descriptor.getPath().length > 1;
67+
if (batchReadEnabled && (!(isNested && isDecimalType(descriptor)))) {
6868
switch (descriptor.getPrimitiveType().getPrimitiveTypeName()) {
6969
case BOOLEAN:
7070
return isNested ? new BooleanNestedBatchReader(descriptor) : new BooleanFlatBatchReader(descriptor);

0 commit comments

Comments
 (0)