Pass temporary tableInfo to tableScan and tableWriter
jaystarshot committed Nov 21, 2024
1 parent ec60c13 commit 8303013
Showing 45 changed files with 144 additions and 83 deletions.
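
This commit replaces the writer-side flag Optional<Boolean> isTemporaryTableWriter with an Optional<TemporaryTableInfo> that is threaded through both the TableScanNode and TableWriterNode constructors. The TemporaryTableInfo class itself is not among the visible hunks; the following is only a sketch of its likely shape, inferred from the call sites below (cteId.map(TemporaryTableInfo::new) and getTemporaryTableInfo().isPresent()):

    // Sketch only -- inferred from this diff's call sites, not copied from the SPI source.
    import static java.util.Objects.requireNonNull;

    public class TemporaryTableInfo
    {
        // id of the materialized CTE that this temporary table backs
        private final String cteId;

        public TemporaryTableInfo(String cteId)
        {
            this.cteId = requireNonNull(cteId, "cteId is null");
        }

        public String getCteId()
        {
            return cteId;
        }
    }

Where no CTE id applies (exchange materialization, regular table writes), the callers in this diff simply pass Optional.empty().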
@@ -139,7 +139,7 @@ public PlanNode visitFilter(FilterNode node, RewriteContext<Void> context)
                 oldTableScanNode.getAssignments(),
                 oldTableScanNode.getTableConstraints(),
                 oldTableScanNode.getCurrentConstraint(),
-                oldTableScanNode.getEnforcedConstraint());
+                oldTableScanNode.getEnforcedConstraint(), Optional.empty());

         return new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), newTableScanNode, node.getPredicate());
     }
@@ -200,7 +200,7 @@ private Optional<PlanNode> tryCreatingNewScanNode(PlanNode plan)
                 ImmutableList.copyOf(assignments.keySet()),
                 assignments.entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, (e) -> (ColumnHandle) (e.getValue()))),
                 tableScanNode.getCurrentConstraint(),
-                tableScanNode.getEnforcedConstraint()));
+                tableScanNode.getEnforcedConstraint(), Optional.empty()));
     }

     @Override
@@ -288,7 +288,7 @@ public PlanNode visitFilter(FilterNode node, Void context)
                 oldTableScanNode.getOutputVariables(),
                 oldTableScanNode.getAssignments(),
                 oldTableScanNode.getCurrentConstraint(),
-                oldTableScanNode.getEnforcedConstraint());
+                oldTableScanNode.getEnforcedConstraint(), Optional.empty());

         return new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), newTableScanNode, node.getPredicate());
     }
@@ -173,7 +173,7 @@ private Optional<PlanNode> tryCreatingNewScanNode(PlanNode plan)
                 assignments.entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, (e) -> (ColumnHandle) (e.getValue()))),
                 tableScanNode.getTableConstraints(),
                 tableScanNode.getCurrentConstraint(),
-                tableScanNode.getEnforcedConstraint()));
+                tableScanNode.getEnforcedConstraint(), Optional.empty()));
     }

     @Override
@@ -147,7 +147,7 @@ protected TableScanNode tableScan(PlanBuilder planBuilder, DruidTableHandle conn
                 variables,
                 assignments.build(),
                 TupleDomain.all(),
-                TupleDomain.all());
+                TupleDomain.all(), Optional.empty());
     }

     protected FilterNode filter(PlanBuilder planBuilder, PlanNode source, RowExpression predicate)
@@ -307,7 +307,7 @@ private static TableScanNode getTableScanNode(
                 tableScan.getAssignments(),
                 tableScan.getTableConstraints(),
                 pushdownFilterResult.getLayout().getPredicate(),
-                TupleDomain.all());
+                TupleDomain.all(), tableScan.getTemporaryTableInfo());
     }

     private static ExtractionResult intersectExtractionResult(
@@ -71,7 +71,7 @@ public PlanNode visitTableScan(TableScanNode tableScan, RewriteContext<Void> con
                 tableScan.getAssignments(),
                 tableScan.getTableConstraints(),
                 tableScan.getCurrentConstraint(),
-                tableScan.getEnforcedConstraint());
+                tableScan.getEnforcedConstraint(), tableScan.getTemporaryTableInfo());
         }
     }
 }
@@ -297,7 +297,7 @@ private Optional<PlanNode> tryPartialAggregationPushdown(PlanNode plan)
                 ImmutableMap.copyOf(assignments),
                 oldTableScanNode.getTableConstraints(),
                 oldTableScanNode.getCurrentConstraint(),
-                oldTableScanNode.getEnforcedConstraint()));
+                oldTableScanNode.getEnforcedConstraint(), oldTableScanNode.getTemporaryTableInfo()));
     }

     @Override
@@ -107,7 +107,7 @@ public void testPersistentCteWithTimeStampWithTimeZoneType()
                 " (CAST('2023-12-31 23:59:59.999 UTC' AS TIMESTAMP WITH TIME ZONE))" +
                 " ) AS t(ts)" +
                 ")" +
-                "SELECT ts FROM cte";
+                "SELECT * FROM cte JOIN cte ON true";
         QueryRunner queryRunner = getQueryRunner();
         verifyResults(queryRunner, testQuery, ImmutableList.of(generateMaterializedCTEInformation("cte", 1, false, true)));
     }
@@ -341,7 +341,7 @@ private TableScanNode createDeletesTableScan(ImmutableMap<VariableReferenceExpre
                 outputs,
                 deleteColumnAssignments,
                 TupleDomain.all(),
-                TupleDomain.all());
+                TupleDomain.all(), Optional.empty());
     }

     /**
@@ -382,7 +382,8 @@ private TableScanNode createNewRoot(TableScanNode node, IcebergTableHandle icebe
                 assignmentsBuilder.build(),
                 node.getTableConstraints(),
                 node.getCurrentConstraint(),
-                node.getEnforcedConstraint());
+                node.getEnforcedConstraint(),
+                node.getTemporaryTableInfo());
     }

     /**
@@ -238,7 +238,7 @@ public PlanNode visitFilter(FilterNode filter, RewriteContext<Void> context)
                         .intersect(tableScan.getCurrentConstraint()),
                 predicateNotChangedBySimplification ?
                         identityPartitionColumnPredicate.intersect(tableScan.getEnforcedConstraint()) :
-                        tableScan.getEnforcedConstraint());
+                        tableScan.getEnforcedConstraint(), Optional.empty());

         if (TRUE_CONSTANT.equals(remainingFilterExpression) && predicateNotChangedBySimplification) {
             return newTableScan;
@@ -110,8 +110,11 @@ public final class SqlStageExecution

     private final Map<InternalNode, Set<RemoteTask>> tasks = new ConcurrentHashMap<>();

+    private final Map<TaskId, RemoteTask> taskMap = new ConcurrentHashMap<>();
+
     @GuardedBy("this")
     private final AtomicInteger nextTaskId = new AtomicInteger();
+
     @GuardedBy("this")
     private final Set<TaskId> allTasks = newConcurrentHashSet();
     @GuardedBy("this")
@@ -512,6 +515,9 @@ public synchronized Set<RemoteTask> scheduleSplits(InternalNode node, Multimap<P

     private synchronized RemoteTask scheduleTask(InternalNode node, TaskId taskId, Multimap<PlanNodeId, Split> sourceSplits)
     {
+        if (taskMap.containsKey(taskId)) {
+            return taskMap.get(taskId);
+        }
         checkArgument(!allTasks.contains(taskId), "A task with id %s already exists", taskId);

         ImmutableMultimap.Builder<PlanNodeId, Split> initialSplits = ImmutableMultimap.builder();
@@ -557,7 +563,7 @@ private synchronized RemoteTask scheduleTask(InternalNode node, TaskId taskId, M
             // stage finished while we were scheduling this task
             task.abort();
         }
-
+        taskMap.put(taskId, task);
         return task;
     }

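The scheduler hunks above make scheduleTask idempotent: the new taskMap remembers every task that has been scheduled, and a repeated call with the same TaskId returns the existing RemoteTask instead of failing the checkArgument. A minimal sketch of that pattern, reusing this diff's field names (createTask is a hypothetical stand-in for the elided method body):

    private final Map<TaskId, RemoteTask> taskMap = new ConcurrentHashMap<>();

    private synchronized RemoteTask scheduleTask(InternalNode node, TaskId taskId)
    {
        RemoteTask existing = taskMap.get(taskId);
        if (existing != null) {
            // already scheduled once (e.g. a repeated scheduling pass); reuse it
            return existing;
        }
        RemoteTask task = createTask(node, taskId); // hypothetical stand-in for the real body
        taskMap.put(taskId, task);
        return task;
    }

Because scheduleTask is synchronized, the get/put pair cannot race with itself; the ConcurrentHashMap only protects readers outside the lock.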
@@ -38,6 +38,7 @@
 import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
 import com.facebook.presto.spi.plan.ProjectNode;
 import com.facebook.presto.spi.plan.TableScanNode;
+import com.facebook.presto.spi.plan.TemporaryTableInfo;
 import com.facebook.presto.spi.relation.ConstantExpression;
 import com.facebook.presto.spi.relation.RowExpression;
 import com.facebook.presto.spi.relation.VariableReferenceExpression;
@@ -78,6 +79,7 @@
 import static com.google.common.collect.ImmutableSet.toImmutableSet;
 import static com.google.common.collect.Iterables.concat;
 import static java.lang.String.format;
+import static java.util.Collections.emptyList;
 import static java.util.function.Function.identity;

 // Planner Util for creating temporary tables
@@ -95,7 +97,8 @@ public static TableScanNode createTemporaryTableScan(
             TableHandle tableHandle,
             List<VariableReferenceExpression> outputVariables,
             Map<VariableReferenceExpression, ColumnMetadata> variableToColumnMap,
-            Optional<PartitioningMetadata> expectedPartitioningMetadata)
+            Optional<PartitioningMetadata> expectedPartitioningMetadata,
+            Optional<String> cteId)
     {
         Map<String, ColumnHandle> columnHandles = metadata.getColumnHandles(session, tableHandle);
         Map<VariableReferenceExpression, ColumnMetadata> outputColumns = outputVariables.stream()
@@ -122,11 +125,14 @@ public static TableScanNode createTemporaryTableScan(
         return new TableScanNode(
                 sourceLocation,
                 idAllocator.getNextId(),
+                Optional.empty(),
                 selectedLayout.getLayout().getNewTableHandle(),
                 outputVariables,
                 assignments,
+                emptyList(),
                 TupleDomain.all(),
-                TupleDomain.all());
+                TupleDomain.all(),
+                cteId.map(TemporaryTableInfo::new));
     }

     public static Map<VariableReferenceExpression, ColumnMetadata> assignTemporaryTableColumnNames(Collection<VariableReferenceExpression> outputVariables,
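With the new trailing parameter, a caller that materializes a CTE can tag the temporary-table scan with the CTE's id, while other callers (such as createRemoteMaterializedExchange later in this diff) pass Optional.empty(). A hedged call-site sketch -- the leading arguments are not visible in this hunk, so every name here other than the last two arguments is assumed from context:

    // Hypothetical call site; argument names and order before the last two are assumptions.
    TableScanNode cteScan = createTemporaryTableScan(
            metadata,
            session,
            idAllocator,
            sourceLocation,
            temporaryTableHandle,
            outputVariables,
            variableToColumnMap,
            Optional.empty(),       // no expected partitioning
            Optional.of("cte_1"));  // hypothetical CTE id; becomes TemporaryTableInfo on the scan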
@@ -177,7 +183,8 @@ public static TableFinishNode createTemporaryTableWriteWithoutExchanges(
             TableHandle tableHandle,
             List<VariableReferenceExpression> outputs,
             Map<VariableReferenceExpression, ColumnMetadata> variableToColumnMap,
-            VariableReferenceExpression outputVar)
+            VariableReferenceExpression outputVar,
+            Optional<String> cteId)
     {
         SchemaTableName schemaTableName = metadata.getTableMetadata(session, tableHandle).getTable();
         TableWriterNode.InsertReference insertReference = new TableWriterNode.InsertReference(tableHandle, schemaTableName);
@@ -208,7 +215,7 @@ public static TableFinishNode createTemporaryTableWriteWithoutExchanges(
                         Optional.empty(),
                         Optional.empty(),
                         Optional.empty(),
-                        Optional.of(Boolean.TRUE)),
+                        cteId.map(TemporaryTableInfo::new)),
                 Optional.of(insertReference),
                 outputVar,
                 Optional.empty(),
@@ -347,7 +354,8 @@ public static TableFinishNode createTemporaryTableWriteWithExchanges(
                         Optional.empty(),
                         enableStatsCollectionForTemporaryTable ? Optional.of(localAggregations.getPartialAggregation()) : Optional.empty(),
                         Optional.empty(),
-                        Optional.of(Boolean.TRUE))),
+                        // ToDO: handle this better
+                        Optional.empty())),
                 variableAllocator.newVariable("intermediaterows", BIGINT),
                 variableAllocator.newVariable("intermediatefragments", VARBINARY),
                 variableAllocator.newVariable("intermediatetablecommitcontext", VARBINARY),
@@ -369,7 +377,7 @@ public static TableFinishNode createTemporaryTableWriteWithExchanges(
                 Optional.empty(),
                 enableStatsCollectionForTemporaryTable ? Optional.of(aggregations.getPartialAggregation()) : Optional.empty(),
                 Optional.empty(),
-                Optional.of(Boolean.TRUE));
+                Optional.empty());
     }

     return new TableFinishNode(
@@ -377,7 +377,7 @@ private PlanNode createRemoteMaterializedExchange(ExchangeNode exchange, Rewrite
                 temporaryTableHandle,
                 exchange.getOutputVariables(),
                 variableToColumnMap,
-                Optional.of(partitioningMetadata));
+                Optional.of(partitioningMetadata), Optional.empty());

         checkArgument(
                 !exchange.getPartitioningScheme().isReplicateNullsAndAny(),
@@ -239,7 +239,7 @@ private RelationPlan createAnalyzePlan(Analysis analysis, Analyze analyzeStateme
                 .putAll(tableScanOutputs.stream().collect(toImmutableMap(identity(), identity())))
                 .putAll(tableStatisticAggregation.getAdditionalVariables())
                 .build();
-        TableScanNode scanNode = new TableScanNode(getSourceLocation(analyzeStatement), idAllocator.getNextId(), targetTable, tableScanOutputs, variableToColumnHandle.build(), TupleDomain.all(), TupleDomain.all());
+        TableScanNode scanNode = new TableScanNode(getSourceLocation(analyzeStatement), idAllocator.getNextId(), targetTable, tableScanOutputs, variableToColumnHandle.build(), TupleDomain.all(), TupleDomain.all(), Optional.empty());
         PlanNode project = PlannerUtils.addProjections(scanNode, idAllocator, assignments);
         PlanNode planNode = new StatisticsWriterNode(
                 getSourceLocation(analyzeStatement),
@@ -445,7 +445,7 @@ private RelationPlan createTableWriterPlan(
                         // the data consumed by the TableWriteOperator
                         Optional.of(aggregations.getPartialAggregation()),
                         Optional.empty(),
-                        Optional.of(Boolean.FALSE)),
+                        Optional.empty()),
                 Optional.of(target),
                 variableAllocator.newVariable("rows", BIGINT),
                 // final aggregation is run within the TableFinishOperator to summarize collected statistics
@@ -474,7 +474,7 @@ private RelationPlan createTableWriterPlan(
                         preferredShufflePartitioningScheme,
                         Optional.empty(),
                         Optional.empty(),
-                        Optional.of(Boolean.FALSE)),
+                        Optional.empty()),
                 Optional.of(target),
                 variableAllocator.newVariable("rows", BIGINT),
                 Optional.empty(),
@@ -256,7 +256,7 @@ public static Set<PlanNodeId> getOutputTableWriterNodeIds(PlanNode plan)
         return stream(forTree(PlanNode::getSources).depthFirstPreOrder(plan))
                 .filter(node -> node instanceof TableWriterNode)
                 .map(node -> (TableWriterNode) node)
-                .filter(tableWriterNode -> !tableWriterNode.getIsTemporaryTableWriter().orElse(false))
+                .filter(tableWriterNode -> !tableWriterNode.getTemporaryTableInfo().isPresent())
                 .map(TableWriterNode::getId)
                 .collect(toImmutableSet());
     }
@@ -306,7 +306,7 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext<Void> context)
                 node.getOutputVariables(),
                 node.getAssignments(),
                 node.getCurrentConstraint(),
-                node.getEnforcedConstraint());
+                node.getEnforcedConstraint(), node.getTemporaryTableInfo());
     }
 }

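The getOutputTableWriterNodeIds hunk above shows the API migration in a single line: a writer is now considered temporary when its TemporaryTableInfo is present, rather than when a boolean flag was set. Side by side, assuming a TableWriterNode writer in scope:

    // Old API (before this commit): tri-state boolean flag
    boolean wasTemporary = writer.getIsTemporaryTableWriter().orElse(false);
    // New API (this commit): presence of the info object
    boolean isTemporary = writer.getTemporaryTableInfo().isPresent();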
@@ -357,7 +357,7 @@ private static TableScanNode cloneTableScan(TableScanNode scanNode, Session sess
                 newAssignments,
                 scanNode.getTableConstraints(),
                 scanNode.getCurrentConstraint(),
-                scanNode.getEnforcedConstraint());
+                scanNode.getEnforcedConstraint(), scanNode.getTemporaryTableInfo());
     }

     public static PlanNode clonePlanNode(PlanNode planNode, Session session, Metadata metadata, PlanNodeIdAllocator planNodeIdAllocator, List<VariableReferenceExpression> fieldsToKeep, Map<VariableReferenceExpression, VariableReferenceExpression> varMap)
@@ -275,7 +275,7 @@ public DeleteNode plan(Delete node)

         // create table scan
         List<VariableReferenceExpression> outputVariables = outputVariablesBuilder.build();
-        PlanNode tableScan = new TableScanNode(getSourceLocation(node), idAllocator.getNextId(), handle, outputVariables, columns.build(), TupleDomain.all(), TupleDomain.all());
+        PlanNode tableScan = new TableScanNode(getSourceLocation(node), idAllocator.getNextId(), handle, outputVariables, columns.build(), TupleDomain.all(), TupleDomain.all(), Optional.empty());
         Scope scope = Scope.builder().withRelationType(RelationId.anonymous(), new RelationType(fields.build())).build();
         RelationPlan relationPlan = new RelationPlan(tableScan, scope, outputVariables);

@@ -344,7 +344,7 @@ public UpdateNode plan(Update node)

         // create table scan
         List<VariableReferenceExpression> outputVariables = outputVariablesBuilder.build();
-        PlanNode tableScan = new TableScanNode(getSourceLocation(node), idAllocator.getNextId(), handle, outputVariables, columns.build(), TupleDomain.all(), TupleDomain.all());
+        PlanNode tableScan = new TableScanNode(getSourceLocation(node), idAllocator.getNextId(), handle, outputVariables, columns.build(), TupleDomain.all(), TupleDomain.all(), Optional.empty());
         Scope scope = Scope.builder().withRelationType(RelationId.anonymous(), new RelationType(fields.build())).build();
         RelationPlan relationPlan = new RelationPlan(tableScan, scope, outputVariables);

@@ -219,7 +219,8 @@ protected RelationPlan visitTable(Table node, SqlPlannerContext context)
         List<VariableReferenceExpression> outputVariables = outputVariablesBuilder.build();
         List<TableConstraint<ColumnHandle>> tableConstraints = metadata.getTableMetadata(session, handle).getMetadata().getTableConstraintsHolder().getTableConstraintsWithColumnHandles();
         context.incrementLeafNodes(session);
-        PlanNode root = new TableScanNode(getSourceLocation(node.getLocation()), idAllocator.getNextId(), handle, outputVariables, columns.build(), tableConstraints, TupleDomain.all(), TupleDomain.all());
+        PlanNode root = new TableScanNode(getSourceLocation(node.getLocation()), idAllocator.getNextId(), handle, outputVariables, columns.build(),
+                tableConstraints, TupleDomain.all(), TupleDomain.all(), Optional.empty());

         return new RelationPlan(root, scope, outputVariables);
     }
@@ -228,7 +228,7 @@ public Result apply(TableScanNode tableScanNode, Captures captures, Context cont
                 tableScanNode.getAssignments(),
                 tableScanNode.getTableConstraints(),
                 layout.getLayout().getPredicate(),
-                TupleDomain.all()));
+                TupleDomain.all(), tableScanNode.getTemporaryTableInfo()));
     }
 }

@@ -324,7 +324,8 @@ private static PlanNode pushPredicateIntoTableScan(
                 node.getAssignments(),
                 node.getTableConstraints(),
                 layout.getLayout().getPredicate(),
-                computeEnforced(newDomain, layout.getUnenforcedConstraint()));
+                computeEnforced(newDomain, layout.getUnenforcedConstraint()),
+                node.getTemporaryTableInfo());

         // The order of the arguments to combineConjuncts matters:
         // * Unenforced constraints go first because they can only be simple column references,
@@ -46,6 +46,6 @@ protected Optional<PlanNode> pushDownProjectOff(PlanNodeIdAllocator idAllocator,
                 filterKeys(tableScanNode.getAssignments(), referencedOutputs::contains),
                 tableScanNode.getTableConstraints(),
                 tableScanNode.getCurrentConstraint(),
-                tableScanNode.getEnforcedConstraint()));
+                tableScanNode.getEnforcedConstraint(), tableScanNode.getTemporaryTableInfo()));
     }
 }
@@ -590,7 +590,7 @@ public Result apply(TableWriterNode node, Captures captures, Context context)
                     node.getPreferredShufflePartitioningScheme(),
                     rewrittenStatisticsAggregation,
                     node.getTaskCountIfScaledWriter(),
-                    node.getIsTemporaryTableWriter()));
+                    node.getTemporaryTableInfo()));
         }
         return Result.empty();
     }
@@ -81,6 +81,6 @@ public Result apply(TableWriterNode node, Captures captures, Context context)
                 node.getPreferredShufflePartitioningScheme(),
                 node.getStatisticsAggregation(),
                 Optional.of(initialTaskNumber),
-                node.getIsTemporaryTableWriter()));
+                node.getTemporaryTableInfo()));
     }
 }