Skip to content

Commit

Permalink
Invalidate ORC metadata cache based on file modification time (presto…
Browse files: browse the repository at this point in the history
  • Loading branch information
nmahadevuni authored Jan 23, 2025
1 parent 15fd5d8 commit cf19b92
Show file tree
Hide file tree
Showing 25 changed files with 274 additions and 107 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -86,12 +86,10 @@
import com.google.inject.Scopes;
import com.google.inject.TypeLiteral;
import com.google.inject.multibindings.Multibinder;
import io.airlift.slice.Slice;
import org.weakref.jmx.MBeanExporter;

import javax.inject.Singleton;

import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.function.Supplier;
Expand All @@ -101,6 +99,8 @@
import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder;
import static com.facebook.airlift.json.smile.SmileCodecBinder.smileCodecBinder;
import static com.facebook.drift.codec.guice.ThriftCodecBinder.thriftCodecBinder;
import static com.facebook.presto.orc.StripeMetadataSource.CacheableRowGroupIndices;
import static com.facebook.presto.orc.StripeMetadataSource.CacheableSlice;
import static com.google.common.util.concurrent.MoreExecutors.listeningDecorator;
import static com.google.inject.multibindings.Multibinder.newSetBinder;
import static java.lang.Math.toIntExact;
Expand Down Expand Up @@ -320,15 +320,15 @@ public StripeMetadataSourceFactory createStripeMetadataSourceFactory(OrcCacheCon
{
StripeMetadataSource stripeMetadataSource = new StorageStripeMetadataSource();
if (orcCacheConfig.isStripeMetadataCacheEnabled()) {
Cache<StripeId, Slice> footerCache = CacheBuilder.newBuilder()
Cache<StripeId, CacheableSlice> footerCache = CacheBuilder.newBuilder()
.maximumWeight(orcCacheConfig.getStripeFooterCacheSize().toBytes())
.weigher((id, footer) -> toIntExact(((Slice) footer).getRetainedSize()))
.weigher((id, footer) -> toIntExact(((CacheableSlice) footer).getSlice().getRetainedSize()))
.expireAfterAccess(orcCacheConfig.getStripeFooterCacheTtlSinceLastAccess().toMillis(), MILLISECONDS)
.recordStats()
.build();
Cache<StripeStreamId, Slice> streamCache = CacheBuilder.newBuilder()
Cache<StripeStreamId, CacheableSlice> streamCache = CacheBuilder.newBuilder()
.maximumWeight(orcCacheConfig.getStripeStreamCacheSize().toBytes())
.weigher((id, stream) -> toIntExact(((Slice) stream).getRetainedSize()))
.weigher((id, stream) -> toIntExact(((CacheableSlice) stream).getSlice().getRetainedSize()))
.expireAfterAccess(orcCacheConfig.getStripeStreamCacheTtlSinceLastAccess().toMillis(), MILLISECONDS)
.recordStats()
.build();
Expand All @@ -337,11 +337,11 @@ public StripeMetadataSourceFactory createStripeMetadataSourceFactory(OrcCacheCon
exporter.export(generatedNameOf(CacheStatsMBean.class, connectorId + "_StripeFooter"), footerCacheStatsMBean);
exporter.export(generatedNameOf(CacheStatsMBean.class, connectorId + "_StripeStream"), streamCacheStatsMBean);

Optional<Cache<StripeStreamId, List<RowGroupIndex>>> rowGroupIndexCache = Optional.empty();
Optional<Cache<StripeStreamId, CacheableRowGroupIndices>> rowGroupIndexCache = Optional.empty();
if (orcCacheConfig.isRowGroupIndexCacheEnabled()) {
rowGroupIndexCache = Optional.of(CacheBuilder.newBuilder()
.maximumWeight(orcCacheConfig.getRowGroupIndexCacheSize().toBytes())
.weigher((id, rowGroupIndices) -> toIntExact(((List<RowGroupIndex>) rowGroupIndices).stream().mapToLong(RowGroupIndex::getRetainedSizeInBytes).sum()))
.weigher((id, rowGroupIndices) -> toIntExact(((CacheableRowGroupIndices) rowGroupIndices).getRowGroupIndices().stream().mapToLong(RowGroupIndex::getRetainedSizeInBytes).sum()))
.expireAfterAccess(orcCacheConfig.getStripeStreamCacheTtlSinceLastAccess().toMillis(), MILLISECONDS)
.recordStats()
.build());
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -104,7 +104,8 @@ public static OrcReader getOrcReader(OrcEncoding orcEncoding, List<HiveColumnHan
hiveFileContext.isCacheable(),
dwrfEncryptionProvider,
dwrfKeyProvider,
hiveFileContext.getStats());
hiveFileContext.getStats(),
hiveFileContext.getModificationTime());
return reader;
}

Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -84,18 +84,18 @@
import com.google.inject.Provides;
import com.google.inject.Scopes;
import com.google.inject.multibindings.Multibinder;
import io.airlift.slice.Slice;
import org.weakref.jmx.MBeanExporter;

import javax.inject.Singleton;

import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutorService;

import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed;
import static com.facebook.airlift.configuration.ConfigBinder.configBinder;
import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder;
import static com.facebook.presto.orc.StripeMetadataSource.CacheableRowGroupIndices;
import static com.facebook.presto.orc.StripeMetadataSource.CacheableSlice;
import static com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService;
import static com.google.inject.multibindings.Multibinder.newSetBinder;
import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder;
Expand Down Expand Up @@ -245,15 +245,15 @@ public StripeMetadataSourceFactory createStripeMetadataSourceFactory(OrcCacheCon
{
StripeMetadataSource stripeMetadataSource = new StorageStripeMetadataSource();
if (orcCacheConfig.isStripeMetadataCacheEnabled()) {
Cache<StripeReader.StripeId, Slice> footerCache = CacheBuilder.newBuilder()
Cache<StripeReader.StripeId, CacheableSlice> footerCache = CacheBuilder.newBuilder()
.maximumWeight(orcCacheConfig.getStripeFooterCacheSize().toBytes())
.weigher((id, footer) -> toIntExact(((Slice) footer).getRetainedSize()))
.weigher((id, footer) -> toIntExact(((CacheableSlice) footer).getSlice().getRetainedSize()))
.expireAfterAccess(orcCacheConfig.getStripeFooterCacheTtlSinceLastAccess().toMillis(), MILLISECONDS)
.recordStats()
.build();
Cache<StripeReader.StripeStreamId, Slice> streamCache = CacheBuilder.newBuilder()
Cache<StripeReader.StripeStreamId, CacheableSlice> streamCache = CacheBuilder.newBuilder()
.maximumWeight(orcCacheConfig.getStripeStreamCacheSize().toBytes())
.weigher((id, stream) -> toIntExact(((Slice) stream).getRetainedSize()))
.weigher((id, stream) -> toIntExact(((CacheableSlice) stream).getSlice().getRetainedSize()))
.expireAfterAccess(orcCacheConfig.getStripeStreamCacheTtlSinceLastAccess().toMillis(), MILLISECONDS)
.recordStats()
.build();
Expand All @@ -262,11 +262,11 @@ public StripeMetadataSourceFactory createStripeMetadataSourceFactory(OrcCacheCon
exporter.export(generatedNameOf(CacheStatsMBean.class, connectorId + "_StripeFooter"), footerCacheStatsMBean);
exporter.export(generatedNameOf(CacheStatsMBean.class, connectorId + "_StripeStream"), streamCacheStatsMBean);

Optional<Cache<StripeReader.StripeStreamId, List<RowGroupIndex>>> rowGroupIndexCache = Optional.empty();
Optional<Cache<StripeReader.StripeStreamId, CacheableRowGroupIndices>> rowGroupIndexCache = Optional.empty();
if (orcCacheConfig.isRowGroupIndexCacheEnabled()) {
rowGroupIndexCache = Optional.of(CacheBuilder.newBuilder()
.maximumWeight(orcCacheConfig.getRowGroupIndexCacheSize().toBytes())
.weigher((id, rowGroupIndices) -> toIntExact(((List<RowGroupIndex>) rowGroupIndices).stream().mapToLong(RowGroupIndex::getRetainedSizeInBytes).sum()))
.weigher((id, rowGroupIndices) -> toIntExact(((CacheableRowGroupIndices) rowGroupIndices).getRowGroupIndices().stream().mapToLong(RowGroupIndex::getRetainedSizeInBytes).sum()))
.expireAfterAccess(orcCacheConfig.getStripeStreamCacheTtlSinceLastAccess().toMillis(), MILLISECONDS)
.recordStats()
.build());
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -156,6 +156,7 @@
import static com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
import static com.facebook.presto.orc.OrcEncoding.ORC;
import static com.facebook.presto.orc.OrcReader.INITIAL_BATCH_SIZE;
import static com.facebook.presto.orc.OrcReader.MODIFICATION_TIME_NOT_SET;
import static com.facebook.presto.parquet.ParquetTypeUtils.getColumnIO;
import static com.facebook.presto.parquet.ParquetTypeUtils.getDescriptors;
import static com.facebook.presto.parquet.ParquetTypeUtils.getParquetTypeByName;
Expand Down Expand Up @@ -505,7 +506,8 @@ private static ConnectorPageSourceWithRowPositions createBatchOrcPageSource(
isCacheable,
dwrfEncryptionProvider,
dwrfKeyProvider,
runtimeStats);
runtimeStats,
MODIFICATION_TIME_NOT_SET);

List<HiveColumnHandle> physicalColumnHandles = new ArrayList<>(regularColumns.size());
ImmutableMap.Builder<Integer, Type> includedColumns = ImmutableMap.builder();
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -155,7 +155,8 @@ public AbstractOrcRecordReader(
StripeMetadataSource stripeMetadataSource,
boolean cacheable,
RuntimeStats runtimeStats,
Optional<OrcFileIntrospector> fileIntrospector)
Optional<OrcFileIntrospector> fileIntrospector,
long fileModificationTime)
{
requireNonNull(includedColumns, "includedColumns is null");
requireNonNull(predicate, "predicate is null");
Expand Down Expand Up @@ -262,7 +263,8 @@ public AbstractOrcRecordReader(
cacheable,
this.dwrfEncryptionGroupMap,
runtimeStats,
fileIntrospector);
fileIntrospector,
fileModificationTime);

this.streamReaders = requireNonNull(streamReaders, "streamReaders is null");
for (int columnId = 0; columnId < root.getFieldCount(); columnId++) {
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -34,7 +34,6 @@
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.ExecutionException;

import static com.facebook.presto.common.RuntimeUnit.BYTE;
import static com.facebook.presto.common.RuntimeUnit.NONE;
Expand All @@ -48,11 +47,11 @@ public class CachingStripeMetadataSource
implements StripeMetadataSource
{
private final StripeMetadataSource delegate;
private final Cache<StripeId, Slice> footerSliceCache;
private final Cache<StripeStreamId, Slice> stripeStreamCache;
private final Optional<Cache<StripeStreamId, List<RowGroupIndex>>> rowGroupIndexCache;
private final Cache<StripeId, CacheableSlice> footerSliceCache;
private final Cache<StripeStreamId, CacheableSlice> stripeStreamCache;
private final Optional<Cache<StripeStreamId, CacheableRowGroupIndices>> rowGroupIndexCache;

public CachingStripeMetadataSource(StripeMetadataSource delegate, Cache<StripeId, Slice> footerSliceCache, Cache<StripeStreamId, Slice> stripeStreamCache, Optional<Cache<StripeStreamId, List<RowGroupIndex>>> rowGroupIndexCache)
public CachingStripeMetadataSource(StripeMetadataSource delegate, Cache<StripeId, CacheableSlice> footerSliceCache, Cache<StripeStreamId, CacheableSlice> stripeStreamCache, Optional<Cache<StripeStreamId, CacheableRowGroupIndices>> rowGroupIndexCache)
{
this.delegate = requireNonNull(delegate, "delegate is null");
this.footerSliceCache = requireNonNull(footerSliceCache, "footerSliceCache is null");
Expand All @@ -61,39 +60,56 @@ public CachingStripeMetadataSource(StripeMetadataSource delegate, Cache<StripeId
}

@Override
public Slice getStripeFooterSlice(OrcDataSource orcDataSource, StripeId stripeId, long footerOffset, int footerLength, boolean cacheable)
public Slice getStripeFooterSlice(OrcDataSource orcDataSource, StripeId stripeId, long footerOffset, int footerLength, boolean cacheable, long fileModificationTime)
throws IOException
{
if (!cacheable) {
return delegate.getStripeFooterSlice(orcDataSource, stripeId, footerOffset, footerLength, cacheable, fileModificationTime);
}
try {
if (!cacheable) {
return delegate.getStripeFooterSlice(orcDataSource, stripeId, footerOffset, footerLength, cacheable);
CacheableSlice cacheableSlice = footerSliceCache.getIfPresent(stripeId);
if (cacheableSlice != null) {
if (cacheableSlice.getFileModificationTime() == fileModificationTime) {
return cacheableSlice.getSlice();
}
footerSliceCache.invalidate(stripeId);
// This get call is to increment the miss count for invalidated entries so the stats are recorded correctly.
footerSliceCache.getIfPresent(stripeId);
}
return footerSliceCache.get(stripeId, () -> delegate.getStripeFooterSlice(orcDataSource, stripeId, footerOffset, footerLength, cacheable));
cacheableSlice = new CacheableSlice(delegate.getStripeFooterSlice(orcDataSource, stripeId, footerOffset, footerLength, cacheable, fileModificationTime), fileModificationTime);
footerSliceCache.put(stripeId, cacheableSlice);
return cacheableSlice.getSlice();
}
catch (ExecutionException | UncheckedExecutionException e) {
catch (UncheckedExecutionException e) {
throwIfInstanceOf(e.getCause(), IOException.class);
throw new IOException("Unexpected error in stripe footer reading after footerSliceCache miss", e.getCause());
}
}

@Override
public Map<StreamId, OrcDataSourceInput> getInputs(OrcDataSource orcDataSource, StripeId stripeId, Map<StreamId, DiskRange> diskRanges, boolean cacheable)
public Map<StreamId, OrcDataSourceInput> getInputs(OrcDataSource orcDataSource, StripeId stripeId, Map<StreamId, DiskRange> diskRanges, boolean cacheable, long fileModificationTime)
throws IOException
{
if (!cacheable) {
return delegate.getInputs(orcDataSource, stripeId, diskRanges, cacheable);
return delegate.getInputs(orcDataSource, stripeId, diskRanges, cacheable, fileModificationTime);
}

// Fetch existing stream slice from cache
ImmutableMap.Builder<StreamId, OrcDataSourceInput> inputsBuilder = ImmutableMap.builder();
ImmutableMap.Builder<StreamId, DiskRange> uncachedDiskRangesBuilder = ImmutableMap.builder();
for (Entry<StreamId, DiskRange> entry : diskRanges.entrySet()) {
StripeStreamId stripeStreamId = new StripeStreamId(stripeId, entry.getKey());
if (isCachedStream(entry.getKey().getStreamKind())) {
Slice streamSlice = stripeStreamCache.getIfPresent(new StripeStreamId(stripeId, entry.getKey()));
if (streamSlice != null) {
inputsBuilder.put(entry.getKey(), new OrcDataSourceInput(new BasicSliceInput(streamSlice), streamSlice.length()));
CacheableSlice streamSlice = stripeStreamCache.getIfPresent(stripeStreamId);
if (streamSlice != null && streamSlice.getFileModificationTime() == fileModificationTime) {
inputsBuilder.put(entry.getKey(), new OrcDataSourceInput(new BasicSliceInput(streamSlice.getSlice()), streamSlice.getSlice().length()));
}
else {
if (streamSlice != null) {
stripeStreamCache.invalidate(stripeStreamId);
// This get call is to increment the miss count for invalidated entries so the stats are recorded correctly.
stripeStreamCache.getIfPresent(stripeStreamId);
}
uncachedDiskRangesBuilder.put(entry);
}
}
Expand All @@ -103,12 +119,12 @@ public Map<StreamId, OrcDataSourceInput> getInputs(OrcDataSource orcDataSource,
}

// read ranges and update cache
Map<StreamId, OrcDataSourceInput> uncachedInputs = delegate.getInputs(orcDataSource, stripeId, uncachedDiskRangesBuilder.build(), cacheable);
Map<StreamId, OrcDataSourceInput> uncachedInputs = delegate.getInputs(orcDataSource, stripeId, uncachedDiskRangesBuilder.build(), cacheable, fileModificationTime);
for (Entry<StreamId, OrcDataSourceInput> entry : uncachedInputs.entrySet()) {
if (isCachedStream(entry.getKey().getStreamKind())) {
// We need to rewind the input after eagerly reading the slice.
Slice streamSlice = Slices.wrappedBuffer(entry.getValue().getInput().readSlice(toIntExact(entry.getValue().getInput().length())).getBytes());
stripeStreamCache.put(new StripeStreamId(stripeId, entry.getKey()), streamSlice);
stripeStreamCache.put(new StripeStreamId(stripeId, entry.getKey()), new CacheableSlice(streamSlice, fileModificationTime));
inputsBuilder.put(entry.getKey(), new OrcDataSourceInput(new BasicSliceInput(streamSlice), toIntExact(streamSlice.getRetainedSize())));
}
else {
Expand All @@ -126,24 +142,31 @@ public List<RowGroupIndex> getRowIndexes(
StreamId streamId,
OrcInputStream inputStream,
List<HiveBloomFilter> bloomFilters,
RuntimeStats runtimeStats)
RuntimeStats runtimeStats,
long fileModificationTime)
throws IOException
{
StripeStreamId stripeStreamId = new StripeStreamId(stripId, streamId);
if (rowGroupIndexCache.isPresent()) {
List<RowGroupIndex> rowGroupIndices = rowGroupIndexCache.get().getIfPresent(new StripeStreamId(stripId, streamId));
if (rowGroupIndices != null) {
CacheableRowGroupIndices cacheableRowGroupIndices = rowGroupIndexCache.get().getIfPresent(stripeStreamId);
if (cacheableRowGroupIndices != null && cacheableRowGroupIndices.getFileModificationTime() == fileModificationTime) {
runtimeStats.addMetricValue("OrcRowGroupIndexCacheHit", NONE, 1);
runtimeStats.addMetricValue("OrcRowGroupIndexInMemoryBytesRead", BYTE, rowGroupIndices.stream().mapToLong(RowGroupIndex::getRetainedSizeInBytes).sum());
return rowGroupIndices;
runtimeStats.addMetricValue("OrcRowGroupIndexInMemoryBytesRead", BYTE, cacheableRowGroupIndices.getRowGroupIndices().stream().mapToLong(RowGroupIndex::getRetainedSizeInBytes).sum());
return cacheableRowGroupIndices.getRowGroupIndices();
}
else {
if (cacheableRowGroupIndices != null) {
rowGroupIndexCache.get().invalidate(stripeStreamId);
// This get call is to increment the miss count for invalidated entries so the stats are recorded correctly.
rowGroupIndexCache.get().getIfPresent(stripeStreamId);
}
runtimeStats.addMetricValue("OrcRowGroupIndexCacheHit", NONE, 0);
runtimeStats.addMetricValue("OrcRowGroupIndexStorageBytesRead", BYTE, inputStream.getRetainedSizeInBytes());
}
}
List<RowGroupIndex> rowGroupIndices = delegate.getRowIndexes(metadataReader, hiveWriterVersion, stripId, streamId, inputStream, bloomFilters, runtimeStats);
List<RowGroupIndex> rowGroupIndices = delegate.getRowIndexes(metadataReader, hiveWriterVersion, stripId, streamId, inputStream, bloomFilters, runtimeStats, fileModificationTime);
if (rowGroupIndexCache.isPresent()) {
rowGroupIndexCache.get().put(new StripeStreamId(stripId, streamId), rowGroupIndices);
rowGroupIndexCache.get().put(stripeStreamId, new CacheableRowGroupIndices(rowGroupIndices, fileModificationTime));
}
return rowGroupIndices;
}
Expand Down
Loading

0 comments on commit cf19b92

Please sign in to comment.