Skip to content

Commit

Permalink
modernizer for presto-orc (#23991)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZacBlanco committed Nov 11, 2024
1 parent bb10ddc commit bb9d4e3
Show file tree
Hide file tree
Showing 57 changed files with 371 additions and 319 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,10 @@ private static void includeOrcColumnsRecursive(List<OrcType> types, Set<Integer>
List<Subfield> subfields = ImmutableList.of();
if (requiredFields.isPresent()) {
String fieldName = type.getFieldNames().get(i).toLowerCase(Locale.ENGLISH);
- if (!requiredFields.get().containsKey(fieldName)) {
+ if (!requiredFields.orElseThrow().containsKey(fieldName)) {
continue;
}
- subfields = requiredFields.get().get(fieldName);
+ subfields = requiredFields.orElseThrow().get(fieldName);
}

includeOrcColumnsRecursive(types, result, type.getFieldTypeIndex(i), subfields);
Expand Down Expand Up @@ -436,7 +436,7 @@ private static boolean isStripeIncluded(
if (!stripeStats.isPresent()) {
return true;
}
- return predicate.matches(stripe.getNumberOfRows(), getStatisticsByColumnOrdinal(rootStructType, stripeStats.get().getColumnStatistics()));
+ return predicate.matches(stripe.getNumberOfRows(), getStatisticsByColumnOrdinal(rootStructType, stripeStats.orElseThrow().getColumnStatistics()));
}

@VisibleForTesting
Expand Down Expand Up @@ -511,7 +511,7 @@ public void close()
}
rowGroups = null;
if (writeChecksumBuilder.isPresent()) {
- OrcWriteValidation.WriteChecksum actualChecksum = writeChecksumBuilder.get().build();
+ OrcWriteValidation.WriteChecksum actualChecksum = writeChecksumBuilder.orElseThrow().build();
validateWrite(validation -> validation.getChecksum().getTotalRowCount() == actualChecksum.getTotalRowCount(), "Invalid row count");
List<Long> columnHashes = actualChecksum.getColumnHashes();
for (int i = 0; i < columnHashes.size(); i++) {
Expand All @@ -522,8 +522,8 @@ public void close()
validateWrite(validation -> validation.getChecksum().getStripeHash() == actualChecksum.getStripeHash(), "Invalid stripes checksum");
}
if (fileStatisticsValidation.isPresent()) {
- List<ColumnStatistics> columnStatistics = fileStatisticsValidation.get().build();
- writeValidation.get().validateFileStatistics(orcDataSource.getId(), columnStatistics);
+ List<ColumnStatistics> columnStatistics = fileStatisticsValidation.orElseThrow().build();
+ writeValidation.orElseThrow().validateFileStatistics(orcDataSource.getId(), columnStatistics);
}
}

Expand All @@ -544,9 +544,9 @@ private boolean advanceToNextRowGroup()

if (currentRowGroup >= 0) {
if (rowGroupStatisticsValidation.isPresent()) {
- OrcWriteValidation.StatisticsValidation statisticsValidation = rowGroupStatisticsValidation.get();
+ OrcWriteValidation.StatisticsValidation statisticsValidation = rowGroupStatisticsValidation.orElseThrow();
long offset = stripes.get(currentStripe).getOffset();
- writeValidation.get().validateRowGroupStatistics(orcDataSource.getId(), offset, currentRowGroup, statisticsValidation.build());
+ writeValidation.orElseThrow().validateRowGroupStatistics(orcDataSource.getId(), offset, currentRowGroup, statisticsValidation.build());
statisticsValidation.reset();
}
}
Expand Down Expand Up @@ -637,9 +637,9 @@ private void advanceToNextStripe()

if (currentStripe >= 0) {
if (stripeStatisticsValidation.isPresent()) {
- OrcWriteValidation.StatisticsValidation statisticsValidation = stripeStatisticsValidation.get();
+ OrcWriteValidation.StatisticsValidation statisticsValidation = stripeStatisticsValidation.orElseThrow();
long offset = stripes.get(currentStripe).getOffset();
- writeValidation.get().validateStripeStatistics(orcDataSource.getId(), offset, statisticsValidation.build());
+ writeValidation.orElseThrow().validateStripeStatistics(orcDataSource.getId(), offset, statisticsValidation.build());
statisticsValidation.reset();
}
}
Expand All @@ -661,9 +661,9 @@ private void advanceToNextStripe()
// or it has been set, but we have new decryption keys,
// set dwrfEncryptionInfo
if ((!stripeDecryptionKeyMetadata.isEmpty() && !dwrfEncryptionInfo.isPresent())
- || (dwrfEncryptionInfo.isPresent() && !stripeDecryptionKeyMetadata.equals(dwrfEncryptionInfo.get().getEncryptedKeyMetadatas()))) {
+ || (dwrfEncryptionInfo.isPresent() && !stripeDecryptionKeyMetadata.equals(dwrfEncryptionInfo.orElseThrow().getEncryptedKeyMetadatas()))) {
verify(encryptionLibrary.isPresent(), "encryptionLibrary is absent");
- dwrfEncryptionInfo = Optional.of(createDwrfEncryptionInfo(encryptionLibrary.get(), stripeDecryptionKeyMetadata, intermediateKeyMetadata, dwrfEncryptionGroupMap));
+ dwrfEncryptionInfo = Optional.of(createDwrfEncryptionInfo(encryptionLibrary.orElseThrow(), stripeDecryptionKeyMetadata, intermediateKeyMetadata, dwrfEncryptionGroupMap));
}

SharedBuffer sharedDecompressionBuffer = new SharedBuffer(currentStripeSystemMemoryContext.newOrcLocalMemoryContext("sharedDecompressionBuffer"));
Expand Down Expand Up @@ -698,15 +698,15 @@ public static List<byte[]> getDecryptionKeyMetadata(int currentStripe, List<Stri
private void validateWrite(Predicate<OrcWriteValidation> test, String messageFormat, Object... args)
throws OrcCorruptionException
{
- if (writeValidation.isPresent() && !test.apply(writeValidation.get())) {
+ if (writeValidation.isPresent() && !test.apply(writeValidation.orElseThrow())) {
throw new OrcCorruptionException(orcDataSource.getId(), "Write validation failed: " + messageFormat, args);
}
}

private void validateWriteStripe(long rowCount)
{
if (writeChecksumBuilder.isPresent()) {
- writeChecksumBuilder.get().addStripe(rowCount);
+ writeChecksumBuilder.orElseThrow().addStripe(rowCount);
}
}

Expand Down Expand Up @@ -775,10 +775,10 @@ protected boolean shouldValidateWritePageChecksum()
protected void validateWritePageChecksum(Page page)
{
if (writeChecksumBuilder.isPresent()) {
- writeChecksumBuilder.get().addPage(page);
- rowGroupStatisticsValidation.get().addPage(page);
- stripeStatisticsValidation.get().addPage(page);
- fileStatisticsValidation.get().addPage(page);
+ writeChecksumBuilder.orElseThrow().addPage(page);
+ rowGroupStatisticsValidation.orElseThrow().addPage(page);
+ stripeStatisticsValidation.orElseThrow().addPage(page);
+ fileStatisticsValidation.orElseThrow().addPage(page);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public List<RowGroupIndex> getRowIndexes(
throws IOException
{
if (rowGroupIndexCache.isPresent()) {
- List<RowGroupIndex> rowGroupIndices = rowGroupIndexCache.get().getIfPresent(new StripeStreamId(stripId, streamId));
+ List<RowGroupIndex> rowGroupIndices = rowGroupIndexCache.orElseThrow().getIfPresent(new StripeStreamId(stripId, streamId));
if (rowGroupIndices != null) {
runtimeStats.addMetricValue("OrcRowGroupIndexCacheHit", NONE, 1);
runtimeStats.addMetricValue("OrcRowGroupIndexInMemoryBytesRead", BYTE, rowGroupIndices.stream().mapToLong(RowGroupIndex::getRetainedSizeInBytes).sum());
Expand All @@ -143,7 +143,7 @@ public List<RowGroupIndex> getRowIndexes(
}
List<RowGroupIndex> rowGroupIndices = delegate.getRowIndexes(metadataReader, hiveWriterVersion, stripId, streamId, inputStream, bloomFilters, runtimeStats);
if (rowGroupIndexCache.isPresent()) {
- rowGroupIndexCache.get().put(new StripeStreamId(stripId, streamId), rowGroupIndices);
+ rowGroupIndexCache.orElseThrow().put(new StripeStreamId(stripId, streamId), rowGroupIndices);
}
return rowGroupIndices;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public Slice getStripeFooterSlice(OrcDataSource orcDataSource, StripeId stripeId
{
Optional<Slice> stripeFooterSlice = stripeCache.getStripeFooterSlice(stripeId, footerLength);
if (stripeFooterSlice.isPresent()) {
- return stripeFooterSlice.get();
+ return stripeFooterSlice.orElseThrow();
}
return delegate.getStripeFooterSlice(orcDataSource, stripeId, footerOffset, footerLength, cacheable);
}
Expand All @@ -71,7 +71,7 @@ public Map<StreamId, OrcDataSourceInput> getInputs(OrcDataSource orcDataSource,
return delegate.getInputs(orcDataSource, stripeId, diskRanges, cacheable);
}

- Slice cacheSlice = stripeCacheIndexStreamsSlice.get();
+ Slice cacheSlice = stripeCacheIndexStreamsSlice.orElseThrow();
ImmutableMap.Builder<StreamId, OrcDataSourceInput> inputsBuilder = ImmutableMap.builder();
ImmutableMap.Builder<StreamId, DiskRange> dataStreamsBuilder = ImmutableMap.builder();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public StripeMetadataSource create(Optional<DwrfStripeCache> dwrfStripeCache)
{
StripeMetadataSource delegate = requireNonNull(delegateFactory.create(dwrfStripeCache), "created delegate is null");
if (dwrfStripeCache.isPresent()) {
- return new DwrfAwareStripeMetadataSource(delegate, dwrfStripeCache.get());
+ return new DwrfAwareStripeMetadataSource(delegate, dwrfStripeCache.orElseThrow());
}
return delegate;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ private void writeChunkToOutputStream(byte[] chunk, int offset, int length)
}
}
if (dwrfEncryptor.isPresent()) {
- chunk = dwrfEncryptor.get().encrypt(chunk, offset, length);
+ chunk = dwrfEncryptor.orElseThrow().encrypt(chunk, offset, length);
length = chunk.length;
offset = 0;
// size after encryption should not exceed what the 3 byte header can hold (2^23)
Expand Down
18 changes: 9 additions & 9 deletions presto-orc/src/main/java/com/facebook/presto/orc/OrcReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,11 @@ public OrcReader(
requireNonNull(dwrfKeyProvider, "dwrfKeyProvider is null");
validateEncryption(footer, this.orcDataSource.getId());
this.dwrfEncryptionGroupMap = createNodeToGroupMap(
- encryption.get().getEncryptionGroups().stream()
+ encryption.orElseThrow().getEncryptionGroups().stream()
.map(EncryptionGroup::getNodes)
.collect(toImmutableList()),
footer.getTypes());
- this.encryptionLibrary = Optional.of(dwrfEncryptionProvider.getEncryptionLibrary(encryption.get().getKeyProvider()));
+ this.encryptionLibrary = Optional.of(dwrfEncryptionProvider.getEncryptionLibrary(encryption.orElseThrow().getKeyProvider()));
this.columnsToIntermediateKeys = ImmutableMap.copyOf(dwrfKeyProvider.getIntermediateKeys(footer.getTypes()));
}
else {
Expand All @@ -228,17 +228,17 @@ public OrcReader(
validateWrite(writeValidation, orcDataSource, validation -> validation.getColumnNames().equals(footer.getTypes().get(0).getFieldNames()), "Unexpected column names");
validateWrite(writeValidation, orcDataSource, validation -> validation.getRowGroupMaxRowCount() == footer.getRowsInRowGroup(), "Unexpected rows in group");
if (writeValidation.isPresent()) {
- writeValidation.get().validateMetadata(orcDataSource.getId(), footer.getUserMetadata());
- writeValidation.get().validateFileStatistics(orcDataSource.getId(), footer.getFileStats());
- writeValidation.get().validateStripeStatistics(orcDataSource.getId(), footer.getStripes(), metadata.getStripeStatsList());
+ writeValidation.orElseThrow().validateMetadata(orcDataSource.getId(), footer.getUserMetadata());
+ writeValidation.orElseThrow().validateFileStatistics(orcDataSource.getId(), footer.getFileStats());
+ writeValidation.orElseThrow().validateStripeStatistics(orcDataSource.getId(), footer.getStripes(), metadata.getStripeStatsList());
}

this.cacheable = requireNonNull(cacheable, "cacheable is null");

Optional<DwrfStripeCache> dwrfStripeCache = Optional.empty();
if (orcFileTail.getDwrfStripeCacheData().isPresent() && footer.getDwrfStripeCacheOffsets().isPresent()) {
- DwrfStripeCacheData dwrfStripeCacheData = orcFileTail.getDwrfStripeCacheData().get();
- DwrfStripeCache cache = dwrfStripeCacheData.buildDwrfStripeCache(footer.getStripes(), footer.getDwrfStripeCacheOffsets().get());
+ DwrfStripeCacheData dwrfStripeCacheData = orcFileTail.getDwrfStripeCacheData().orElseThrow();
+ DwrfStripeCache cache = dwrfStripeCacheData.buildDwrfStripeCache(footer.getStripes(), footer.getDwrfStripeCacheOffsets().orElseThrow());
dwrfStripeCache = Optional.of(cache);
}

Expand All @@ -252,7 +252,7 @@ public static void validateEncryption(Footer footer, OrcDataSourceId dataSourceI
if (!footer.getEncryption().isPresent()) {
return;
}
- DwrfEncryption dwrfEncryption = footer.getEncryption().get();
+ DwrfEncryption dwrfEncryption = footer.getEncryption().orElseThrow();
int encryptionGroupSize = dwrfEncryption.getEncryptionGroups().size();
List<StripeInformation> stripes = footer.getStripes();
if (!stripes.isEmpty() && encryptionGroupSize > 0 && stripes.get(0).getKeyMetadata().isEmpty()) {
Expand Down Expand Up @@ -483,7 +483,7 @@ static void validateFile(
public static void validateWrite(Optional<OrcWriteValidation> writeValidation, OrcDataSource orcDataSource, Predicate<OrcWriteValidation> test, String messageFormat, Object... args)
throws OrcCorruptionException
{
- if (writeValidation.isPresent() && !test.test(writeValidation.get())) {
+ if (writeValidation.isPresent() && !test.test(writeValidation.orElseThrow())) {
throw new OrcCorruptionException(orcDataSource.getId(), "Write validation failed: " + messageFormat, args);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import com.facebook.presto.orc.reader.SelectiveStreamReader;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
- import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import com.google.common.primitives.Ints;
import io.airlift.slice.Slice;
Expand Down Expand Up @@ -84,6 +83,7 @@
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
+ import static com.google.common.collect.MoreCollectors.onlyElement;
import static io.airlift.slice.SizeOf.sizeOf;
import static java.lang.Math.max;
import static java.lang.Math.min;
Expand Down Expand Up @@ -470,7 +470,7 @@ private static Optional<FilterFunction> getFilterFunctionWithoutInputs(List<Filt
return Optional.empty();
}

- return Optional.of(Iterables.getOnlyElement(functions));
+ return Optional.of(functions.stream().collect(onlyElement()));
}

private static boolean containsNonNullFilter(Map<Subfield, TupleDomainFilter> columnFilters)
Expand All @@ -487,7 +487,7 @@ private static int scoreFilter(Map<Subfield, TupleDomainFilter> filters)
return 1000;
}

- Map.Entry<Subfield, TupleDomainFilter> filterEntry = Iterables.getOnlyElement(filters.entrySet());
+ Map.Entry<Subfield, TupleDomainFilter> filterEntry = filters.entrySet().stream().collect(onlyElement());
if (!filterEntry.getKey().getPath().isEmpty()) {
// Complex type column. Complex types are expensive!
return 1000;
Expand Down Expand Up @@ -785,7 +785,7 @@ private int applyFilterFunctionWithNoInputs(int positionCount)
{
initializeOutputPositions(positionCount);
Page page = new Page(positionCount);
- return filterFunctionWithoutInput.get().filter(page, outputPositions, positionCount, errors);
+ return filterFunctionWithoutInput.orElseThrow().filter(page, outputPositions, positionCount, errors);
}

private int applyFilterFunctions(List<FilterFunctionWithStats> filterFunctions, Set<Integer> filterFunctionInputs, int[] positions, int positionCount)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedMap;
- import com.google.common.collect.Iterables;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.slice.XxHash64;
Expand Down Expand Up @@ -94,6 +93,7 @@
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
+ import static com.google.common.collect.MoreCollectors.onlyElement;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -846,7 +846,7 @@ else if (type instanceof DecimalType) {
else if (type.getTypeSignature().getBase().equals(ARRAY)) {
statisticsBuilder = new CountStatisticsBuilder();
fieldExtractor = block -> ImmutableList.of(toColumnarArray(block).getElementsBlock());
- fieldBuilders = ImmutableList.of(new ColumnStatisticsValidation(Iterables.getOnlyElement(type.getTypeParameters())));
+ fieldBuilders = ImmutableList.of(new ColumnStatisticsValidation(type.getTypeParameters().stream().collect(onlyElement())));
}
else if (type.getTypeSignature().getBase().equals(MAP)) {
statisticsBuilder = new CountStatisticsBuilder();
Expand Down
Loading

0 comments on commit bb9d4e3

Please sign in to comment.