|
96 | 96 | import static com.facebook.presto.tests.StructuralTestUtil.mapType;
|
97 | 97 | import static com.google.common.base.Functions.compose;
|
98 | 98 | import static com.google.common.base.Preconditions.checkArgument;
|
| 99 | +import static com.google.common.collect.ImmutableList.toImmutableList; |
99 | 100 | import static com.google.common.collect.Iterables.concat;
|
100 | 101 | import static com.google.common.collect.Iterables.cycle;
|
101 | 102 | import static com.google.common.collect.Iterables.limit;
|
@@ -181,6 +182,105 @@ public void testNestedArrays()
|
181 | 182 | tester.testRoundTrip(objectInspector, values, values, type);
|
182 | 183 | }
|
183 | 184 |
|
| 185 | + @Test |
| 186 | + public void testNestedArraysDecimalBackedByINT32() |
| 187 | + throws Exception |
| 188 | + { |
| 189 | + int precision = 1; |
| 190 | + int scale = 0; |
| 191 | + ObjectInspector objectInspector = getStandardListObjectInspector(javaIntObjectInspector); |
| 192 | + Type type = new ArrayType(createDecimalType(precision, scale)); |
| 193 | + Iterable<List<Integer>> values = createTestArrays(intsBetween(1, 1_000)); |
| 194 | + |
| 195 | + ImmutableList.Builder<List<SqlDecimal>> expectedValues = new ImmutableList.Builder<>(); |
| 196 | + for (List<Integer> value : values) { |
| 197 | + expectedValues.add(value.stream() |
| 198 | + .map(valueInt -> SqlDecimal.of(valueInt, precision, scale)) |
| 199 | + .collect(toImmutableList())); |
| 200 | + } |
| 201 | + |
| 202 | + MessageType hiveSchema = parseMessageType(format("message hive_list_decimal {" + |
| 203 | + " optional group my_list (LIST){" + |
| 204 | + " repeated group list {" + |
| 205 | + " optional INT32 element (DECIMAL(%d, %d));" + |
| 206 | + " }" + |
| 207 | + " }" + |
| 208 | + "} ", precision, scale)); |
| 209 | + |
| 210 | + tester.testRoundTrip(objectInspector, values, expectedValues.build(), "my_list", type, Optional.of(hiveSchema)); |
| 211 | + } |
| 212 | + |
| 213 | + @Test |
| 214 | + public void testNestedArraysDecimalBackedByINT64() |
| 215 | + throws Exception |
| 216 | + { |
| 217 | + int precision = 10; |
| 218 | + int scale = 2; |
| 219 | + ObjectInspector objectInspector = getStandardListObjectInspector(javaLongObjectInspector); |
| 220 | + Type type = new ArrayType(createDecimalType(precision, scale)); |
| 221 | + Iterable<List<Long>> values = createTestArrays(longsBetween(1, 1_000)); |
| 222 | + |
| 223 | + ImmutableList.Builder<List<SqlDecimal>> expectedValues = new ImmutableList.Builder<>(); |
| 224 | + for (List<Long> value : values) { |
| 225 | + expectedValues.add(value.stream() |
| 226 | + .map(valueLong -> SqlDecimal.of(valueLong, precision, scale)) |
| 227 | + .collect(toImmutableList())); |
| 228 | + } |
| 229 | + |
| 230 | + MessageType hiveSchema = parseMessageType(format("message hive_list_decimal {" + |
| 231 | + " optional group my_list (LIST){" + |
| 232 | + " repeated group list {" + |
| 233 | + " optional INT64 element (DECIMAL(%d, %d));" + |
| 234 | + " }" + |
| 235 | + " }" + |
| 236 | + "} ", precision, scale)); |
| 237 | + tester.testRoundTrip(objectInspector, values, expectedValues.build(), "my_list", type, Optional.of(hiveSchema)); |
| 238 | + } |
| 239 | + |
| 240 | + @Test |
| 241 | + public void testNestedArraysShortDecimalBackedByBinary() |
| 242 | + throws Exception |
| 243 | + { |
| 244 | + int precision = 1; |
| 245 | + int scale = 0; |
| 246 | + ObjectInspector objectInspector = getStandardListObjectInspector(new JavaHiveDecimalObjectInspector(new DecimalTypeInfo(precision, scale))); |
| 247 | + Type type = new ArrayType(createDecimalType(precision, scale)); |
| 248 | + Iterable<List<HiveDecimal>> values = getNestedDecimalArrayInputValues(precision, scale); |
| 249 | + List<List<SqlDecimal>> expectedValues = getNestedDecimalArrayExpectedValues(values, precision, scale); |
| 250 | + |
| 251 | + MessageType hiveSchema = parseMessageType(format("message hive_list_decimal {" + |
| 252 | + " optional group my_list (LIST){" + |
| 253 | + " repeated group list {" + |
| 254 | + " optional BINARY element (DECIMAL(%d, %d));" + |
| 255 | + " }" + |
| 256 | + " }" + |
| 257 | + "} ", precision, scale)); |
| 258 | + |
| 259 | + tester.testRoundTrip(objectInspector, values, expectedValues, "my_list", type, Optional.of(hiveSchema)); |
| 260 | + } |
| 261 | + |
| 262 | + private Iterable<List<HiveDecimal>> getNestedDecimalArrayInputValues(int precision, int scale) |
| 263 | + { |
| 264 | + ContiguousSet<BigInteger> bigIntegerValues = bigIntegersBetween(BigDecimal.valueOf(Math.pow(10, precision - 1)).toBigInteger(), |
| 265 | + BigDecimal.valueOf(Math.pow(10, precision)).toBigInteger()); |
| 266 | + List<HiveDecimal> writeValues = bigIntegerValues.stream() |
| 267 | + .map(value -> HiveDecimal.create((BigInteger) value, scale)) |
| 268 | + .collect(toImmutableList()); |
| 269 | + |
| 270 | + return createTestArrays(writeValues); |
| 271 | + } |
| 272 | + |
| 273 | + private static List<List<SqlDecimal>> getNestedDecimalArrayExpectedValues(Iterable<List<HiveDecimal>> values, int precision, int scale) |
| 274 | + { |
| 275 | + ImmutableList.Builder<List<SqlDecimal>> expectedValues = new ImmutableList.Builder<>(); |
| 276 | + for (List<HiveDecimal> value : values) { |
| 277 | + expectedValues.add(value.stream() |
| 278 | + .map(valueHiveDecimal -> new SqlDecimal(valueHiveDecimal.unscaledValue(), precision, scale)) |
| 279 | + .collect(toImmutableList())); |
| 280 | + } |
| 281 | + return expectedValues.build(); |
| 282 | + } |
| 283 | + |
184 | 284 | @Test
|
185 | 285 | public void testSingleLevelSchemaNestedArrays()
|
186 | 286 | throws Exception
|
|
0 commit comments