Skip to content

Commit

Permalink
support infer join when comparing mv
Browse files Browse the repository at this point in the history
  • Loading branch information
keanji-x committed Dec 25, 2023
1 parent 1d984e0 commit 1bb154f
Show file tree
Hide file tree
Showing 7 changed files with 362 additions and 181 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.rules.exploration.mv.ComparisonResult;
import org.apache.doris.nereids.rules.exploration.mv.LogicalCompatibilityContext;
import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughJoin;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
Expand All @@ -44,20 +42,14 @@
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -272,11 +264,11 @@ private void makeFilterConflictRules(JoinEdge joinEdge) {
filterEdges.forEach(e -> {
if (LongBitmap.isSubset(e.getReferenceNodes(), leftSubNodes)
&& !PushDownFilterThroughJoin.COULD_PUSH_THROUGH_LEFT.contains(joinEdge.getJoinType())) {
e.addRejectEdge(joinEdge);
e.addLeftRejectEdge(joinEdge);
}
if (LongBitmap.isSubset(e.getReferenceNodes(), rightSubNodes)
&& !PushDownFilterThroughJoin.COULD_PUSH_THROUGH_RIGHT.contains(joinEdge.getJoinType())) {
e.addRejectEdge(joinEdge);
e.addRightRejectEdge(joinEdge);
}
});
}
Expand All @@ -293,23 +285,23 @@ private void makeJoinConflictRules(JoinEdge edgeB) {
JoinEdge childA = joinEdges.get(i);
if (!JoinType.isAssoc(childA.getJoinType(), edgeB.getJoinType())) {
leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getLeftSubNodes(joinEdges));
childA.addRejectEdge(edgeB);
childA.addLeftRejectEdge(edgeB);
}
if (!JoinType.isLAssoc(childA.getJoinType(), edgeB.getJoinType())) {
leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getRightSubNodes(joinEdges));
childA.addRejectEdge(edgeB);
childA.addLeftRejectEdge(edgeB);
}
}

for (int i = rightSubTreeEdges.nextSetBit(0); i >= 0; i = rightSubTreeEdges.nextSetBit(i + 1)) {
JoinEdge childA = joinEdges.get(i);
if (!JoinType.isAssoc(edgeB.getJoinType(), childA.getJoinType())) {
rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getRightSubNodes(joinEdges));
childA.addRejectEdge(edgeB);
childA.addRightRejectEdge(edgeB);
}
if (!JoinType.isRAssoc(edgeB.getJoinType(), childA.getJoinType())) {
rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getLeftSubNodes(joinEdges));
childA.addRejectEdge(edgeB);
childA.addRightRejectEdge(edgeB);
}
}
edgeB.setLeftExtendedNodes(leftRequired);
Expand Down Expand Up @@ -597,157 +589,6 @@ public int edgeSize() {
return joinEdges.size() + filterEdges.size();
}

/**
 * Compare this hyper graph (the query side) with a view hyper graph.
 *
 * <p>Query edges are first paired with view edges that connect the same mapped nodes,
 * the paired edges are then compared by expression, and finally every edge left
 * unpaired must be able to be pulled up, otherwise the comparison is invalid.
 *
 * @param viewHG the compared hyper graph
 * @param ctx supplies the query-to-view node id mapping and the query-to-view
 *            edge expression mapping
 * @return the collected residual query and view expressions, or
 *         {@code ComparisonResult.INVALID} when the graphs are not compatible
 */
public ComparisonResult isLogicCompatible(HyperGraph viewHG, LogicalCompatibilityContext ctx) {
    // 1. try to construct a map which can be mapped from edge to edge
    Map<Edge, Edge> queryToView = constructMapWithNode(viewHG, ctx.getQueryToViewNodeIDMapping());
    // Hash the matched view edges once: step 3 and the view-filter residual in step 4
    // both need membership tests against them.
    Set<Edge> matchedViewEdges = Sets.newHashSet(queryToView.values());

    // 2. compare the paired edges by expression and extract residual expressions
    ComparisonResult edgeCompareRes = compareEdgesWithExpr(queryToView, ctx.getQueryToViewEdgeExpressionMapping());
    if (edgeCompareRes.isInvalid()) {
        return ComparisonResult.INVALID;
    }
    ComparisonResult.Builder builder = new ComparisonResult.Builder();
    builder.addComparisonResult(edgeCompareRes);

    // 3. pulling up a join edge of the view makes no sense, so reject the comparison
    //    unless every view join edge has been matched by some query edge
    if (!matchedViewEdges.containsAll(viewHG.joinEdges)) {
        return ComparisonResult.INVALID;
    }

    // 4. process residual (unmatched) edges; a null result means one of them cannot be pulled up
    List<Expression> residualQueryJoin =
            processOrphanEdges(Sets.difference(Sets.newHashSet(joinEdges), queryToView.keySet()));
    if (residualQueryJoin == null) {
        return ComparisonResult.INVALID;
    }
    builder.addQueryExpressions(residualQueryJoin);

    List<Expression> residualQueryFilter =
            processOrphanEdges(Sets.difference(Sets.newHashSet(filterEdges), queryToView.keySet()));
    if (residualQueryFilter == null) {
        return ComparisonResult.INVALID;
    }
    builder.addQueryExpressions(residualQueryFilter);

    List<Expression> residualViewFilter =
            processOrphanEdges(Sets.difference(Sets.newHashSet(viewHG.filterEdges), matchedViewEdges));
    if (residualViewFilter == null) {
        return ComparisonResult.INVALID;
    }
    builder.addViewExpressions(residualViewFilter);

    return builder.build();
}

/**
 * Collect the expressions of edges that have no counterpart on the other side.
 *
 * @param edges the unmatched edges
 * @return all expressions of {@code edges}, or {@code null} as soon as one edge
 *         reports that it cannot be pulled up
 */
private List<Expression> processOrphanEdges(Set<Edge> edges) {
    List<Expression> pulledUp = new ArrayList<>();
    for (Edge orphan : edges) {
        if (orphan.canPullUp()) {
            pulledUp.addAll(orphan.getExpressions());
            continue;
        }
        // one edge that must stay in place invalidates the whole set
        return null;
    }
    return pulledUp;
}

/**
 * Pair every edge of this graph with the first edge of {@code viewHG} of the same
 * kind that connects the same nodes under {@code nodeMap}.
 *
 * @param viewHG the view hyper graph whose edges are searched
 * @param nodeMap query node id to view node id
 * @return an immutable map from query edge to its matching view edge; query edges
 *         without a match are absent. Join edges come first, then filter edges.
 */
private Map<Edge, Edge> constructMapWithNode(HyperGraph viewHG, Map<Integer, Integer> nodeMap) {
    // TODO use hash map to reduce loop
    ImmutableMap.Builder<Edge, Edge> matched = ImmutableMap.builder();
    for (Edge queryEdge : joinEdges) {
        for (Edge viewEdge : viewHG.joinEdges) {
            if (compareEdgeWithNode(queryEdge, viewEdge, nodeMap)) {
                // only the first matching view edge is kept
                matched.put(queryEdge, viewEdge);
                break;
            }
        }
    }
    for (Edge queryEdge : filterEdges) {
        for (Edge viewEdge : viewHG.filterEdges) {
            if (compareEdgeWithNode(queryEdge, viewEdge, nodeMap)) {
                matched.put(queryEdge, viewEdge);
                break;
            }
        }
    }
    return matched.build();
}

/**
 * Dispatch the node-level comparison by edge kind.
 *
 * @return the result of the filter or join comparison when both edges are of that
 *         kind; {@code false} when the two edges are of different kinds
 */
private boolean compareEdgeWithNode(Edge t, Edge o, Map<Integer, Integer> nodeMap) {
    boolean bothFilters = t instanceof FilterEdge && o instanceof FilterEdge;
    if (bothFilters) {
        return compareEdgeWithFilter((FilterEdge) t, (FilterEdge) o, nodeMap);
    }
    boolean bothJoins = t instanceof JoinEdge && o instanceof JoinEdge;
    if (bothJoins) {
        return compareJoinEdge((JoinEdge) t, (JoinEdge) o, nodeMap);
    }
    return false;
}

/**
 * Two filter edges match when the reference nodes of {@code t}, translated through
 * {@code nodeMap}, are exactly the reference nodes of {@code o}.
 */
private boolean compareEdgeWithFilter(FilterEdge t, FilterEdge o, Map<Integer, Integer> nodeMap) {
    return compareNodeMap(t.getReferenceNodes(), o.getReferenceNodes(), nodeMap);
}

/**
 * Compare two join edges by join type and by the (mapped) extended nodes on each side.
 *
 * <p>The edges match when the join types are equal or one is the swap of the other, and
 * either the sides line up directly, or the join type of {@code t} swapped equals the
 * type of {@code o} and the sides line up crosswise.
 */
private boolean compareJoinEdge(JoinEdge t, JoinEdge o, Map<Integer, Integer> nodeMap) {
    JoinType queryType = t.getJoinType();
    JoinType viewType = o.getJoinType();
    boolean sameType = queryType.equals(viewType);
    boolean swappedType = queryType.swap().equals(viewType);
    if (!sameType && !swappedType) {
        return false;
    }

    long queryLeft = t.getLeftExtendedNodes();
    long queryRight = t.getRightExtendedNodes();
    long viewLeft = o.getLeftExtendedNodes();
    long viewRight = o.getRightExtendedNodes();

    // crosswise match is only allowed when the view edge has the swapped join type
    if (swappedType
            && compareNodeMap(queryRight, viewLeft, nodeMap)
            && compareNodeMap(queryLeft, viewRight, nodeMap)) {
        return true;
    }
    return compareNodeMap(queryLeft, viewLeft, nodeMap) && compareNodeMap(queryRight, viewRight, nodeMap);
}

/**
 * Check whether the nodes of {@code bitmap1}, after translating every node id through
 * {@code nodeIDMap}, form exactly {@code bitmap2}.
 *
 * @param bitmap1 node bitmap on the query side
 * @param bitmap2 node bitmap on the view side
 * @param nodeIDMap query node id to view node id
 * @return {@code true} only when every node of {@code bitmap1} has a mapping and the
 *         mapped bitmap equals {@code bitmap2}
 */
private boolean compareNodeMap(long bitmap1, long bitmap2, Map<Integer, Integer> nodeIDMap) {
    long mappedBitmap = LongBitmap.newBitmap();
    for (int nodeId : LongBitmap.getIterator(bitmap1)) {
        Integer mappedId = nodeIDMap.get(nodeId);
        if (mappedId == null) {
            // A node without a counterpart can never match. Falling back to id 0 would
            // set bit 0 and could make this bitmap compare equal to one that really
            // contains node 0.
            return false;
        }
        mappedBitmap = LongBitmap.set(mappedBitmap, mappedId);
    }
    return bitmap2 == mappedBitmap;
}

/**
 * Compare every paired (query edge, view edge) by expression and merge the results.
 *
 * @param queryToViewedgeMap query edge to its matched view edge
 * @param queryToView query expression to view expression
 * @return the merged result, or {@code ComparisonResult.INVALID} as soon as one pair is invalid
 */
private ComparisonResult compareEdgesWithExpr(Map<Edge, Edge> queryToViewedgeMap,
        Map<Expression, Expression> queryToView) {
    ComparisonResult.Builder merged = new ComparisonResult.Builder();
    for (Entry<Edge, Edge> pair : queryToViewedgeMap.entrySet()) {
        ComparisonResult single = compareEdgeWithExpr(pair.getKey(), pair.getValue(), queryToView);
        if (!single.isInvalid()) {
            merged.addComparisonResult(single);
            continue;
        }
        return ComparisonResult.INVALID;
    }
    return merged.build();
}

/**
 * Compare one query edge with its matched view edge expression by expression.
 *
 * <p>A query expression is matched when {@code queryToView} maps it to an expression
 * that the view edge contains; everything else on either side is residual. Residual
 * expressions are only acceptable on an edge that can be pulled up.
 *
 * @return the residual query and view expressions, or {@code ComparisonResult.INVALID}
 */
private ComparisonResult compareEdgeWithExpr(Edge query, Edge view, Map<Expression, Expression> queryToView) {
    Set<? extends Expression> viewExprs = view.getExpressionSet();

    Set<Expression> matchedViewExprs = new HashSet<>();
    List<Expression> queryResidual = new ArrayList<>();
    for (Expression queryExpr : query.getExpressionSet()) {
        boolean matched = queryToView.containsKey(queryExpr)
                && viewExprs.contains(queryToView.get(queryExpr));
        if (matched) {
            matchedViewExprs.add(queryToView.get(queryExpr));
        } else {
            queryResidual.add(queryExpr);
        }
    }

    // whatever the view edge holds beyond the matched expressions is view residual
    List<Expression> viewResidual = ImmutableList.copyOf(Sets.difference(viewExprs, matchedViewExprs));
    boolean viewStuck = !viewResidual.isEmpty() && !view.canPullUp();
    if (viewStuck) {
        return ComparisonResult.INVALID;
    }
    boolean queryStuck = !queryResidual.isEmpty() && !query.canPullUp();
    if (queryStuck) {
        return ComparisonResult.INVALID;
    }
    return new ComparisonResult(queryResidual, viewResidual);
}

/**
* For the given hyperGraph, make a textual representation in the form
* of a dotty graph. You can save this to a file and then use Graphviz
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.google.common.collect.ImmutableSet;

import java.util.BitSet;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

Expand Down Expand Up @@ -53,7 +54,8 @@ public abstract class Edge {
// record all sub nodes behind in this operator. It's T function in paper
private final long subTreeNodes;

private long rejectNodes = 0;
private final Set<JoinEdge> leftRejectEdges;
private final Set<JoinEdge> rightRejectEdges;

/**
* Create simple edge.
Expand All @@ -69,14 +71,36 @@ public abstract class Edge {
this.leftExtendedNodes = leftRequiredNodes;
this.rightExtendedNodes = rightRequiredNodes;
this.subTreeNodes = subTreeNodes;
this.leftRejectEdges = new HashSet<>();
this.rightRejectEdges = new HashSet<>();
}

/**
 * An edge is simple when both its left and its right extended node sets contain
 * exactly one node.
 */
public boolean isSimple() {
    if (LongBitmap.getCardinality(leftExtendedNodes) != 1) {
        return false;
    }
    return LongBitmap.getCardinality(rightExtendedNodes) == 1;
}

public void addRejectEdge(Edge edge) {
rejectNodes = LongBitmap.newBitmapUnion(edge.getReferenceNodes(), rejectNodes);
public void addLeftRejectEdge(JoinEdge edge) {
leftRejectEdges.add(edge);
}

/**
 * Add {@code edge} to the set of right reject edges of this edge.
 *
 * @param edge the join edge to record
 */
public void addRightRejectEdge(JoinEdge edge) {
rightRejectEdges.add(edge);
}

/**
 * Add every join edge in the given set to the left reject edges of this edge.
 *
 * @param edge the join edges to record (a set, despite the singular name)
 */
public void addLeftRejectEdges(Set<JoinEdge> edge) {
leftRejectEdges.addAll(edge);
}

/**
 * Add every join edge in the given set to the right reject edges of this edge.
 *
 * @param edge the join edges to record (a set, despite the singular name)
 */
public void addRightRejectEdges(Set<JoinEdge> edge) {
rightRejectEdges.addAll(edge);
}

/**
 * Get the left reject edges recorded so far.
 *
 * @return an immutable copy of the left reject edge set; later additions to this
 *         edge are not reflected in the returned set
 */
public Set<JoinEdge> getLeftRejectEdge() {
return ImmutableSet.copyOf(leftRejectEdges);
}

/**
 * Get the right reject edges recorded so far.
 *
 * @return an immutable copy of the right reject edge set; later additions to this
 *         edge are not reflected in the returned set
 */
public Set<JoinEdge> getRightRejectEdge() {
return ImmutableSet.copyOf(rightRejectEdges);
}

public void addLeftExtendNode(long left) {
Expand Down Expand Up @@ -183,16 +207,6 @@ public Set<? extends Expression> getExpressionSet() {
return ImmutableSet.copyOf(getExpressions());
}

/**
 * Whether this edge can be pulled up.
 *
 * @return {@code true} only when the edge has no reject nodes and, if it is a join
 *         edge, its join type is an inner join
 */
public boolean canPullUp() {
    // any reject node pins the edge in place
    if (rejectNodes != 0) {
        return false;
    }
    // among join edges only inner joins qualify; every other edge kind does
    if (this instanceof JoinEdge) {
        return ((JoinEdge) this).getJoinType().isInnerJoin();
    }
    return true;
}

/**
 * Get the reject nodes of this edge.
 *
 * @return the reject node bitmap; {@code 0} when nothing has been rejected
 */
public long getRejectNodes() {
return rejectNodes;
}

/**
 * Get one expression of this edge by position.
 *
 * @param i index into the list returned by {@code getExpressions()}
 * @return the expression at index {@code i}
 */
public Expression getExpression(int i) {
return getExpressions().get(i);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ public JoinEdge(LogicalJoin<? extends Plan, ? extends Plan> join, int index,
this.join = join;
}

/**
 * Create the swapped counterpart of this edge.
 *
 * <p>The new edge is built from the swapped join with the same index and sub tree
 * nodes, while the child edges and the required nodes trade sides. The left and
 * right reject edges are copied over as they are, without trading sides.
 *
 * @return a new swapped edge; this edge is left unchanged
 */
public JoinEdge swap() {
    JoinEdge swapped = new JoinEdge(
            join.swap(),
            getIndex(),
            getRightChildEdges(),
            getLeftChildEdges(),
            getSubTreeNodes(),
            getRightRequiredNodes(),
            getLeftRequiredNodes());
    swapped.addLeftRejectEdges(getLeftRejectEdge());
    swapped.addRightRejectEdges(getRightRejectEdge());
    return swapped;
}

/**
 * Get the join type of the join this edge was built from.
 */
public JoinType getJoinType() {
return join.getJoinType();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,24 @@ public Builder addComparisonResult(ComparisonResult comparisonResult) {
return this;
}

public Builder addQueryExpressions(Collection<Expression> expressions) {
public Builder addQueryExpressions(Collection<? extends Expression> expressions) {
queryBuilder.addAll(expressions);
return this;
}

public Builder addViewExpressions(Collection<Expression> expressions) {
public Builder addViewExpressions(Collection<? extends Expression> expressions) {
viewBuilder.addAll(expressions);
return this;
}

/**
 * Whether this builder has been marked invalid.
 *
 * @return {@code true} when the {@code valid} flag is not set
 */
public boolean isInvalid() {
return !valid;
}

/**
 * Build the comparison result.
 *
 * @return {@code ComparisonResult.INVALID} when this builder is invalid, otherwise a
 *         new result holding the collected query and view expressions
 */
public ComparisonResult build() {
    return isInvalid()
            ? ComparisonResult.INVALID
            : new ComparisonResult(queryBuilder.build(), viewBuilder.build(), valid);
}
}
Expand Down
Loading

0 comments on commit 1bb154f

Please sign in to comment.