Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sql): GROUP BY expr && ORDER BY expr #465

Merged
merged 9 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions antlr/src/main/antlr4/cn/edu/tsinghua/iginx/sql/Sql.g4
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,11 @@ functionName
;

caseSpecification
: simipleCase
: simpleCase
| searchedCase
;

simipleCase
simpleCase
: CASE expression simpleWhenClause (simpleWhenClause)* elseClause? END
;

Expand Down Expand Up @@ -303,7 +303,12 @@ specialClause
;

groupByClause
: GROUP BY path (COMMA path)*
: GROUP BY groupByItem (COMMA groupByItem)*
;

groupByItem
: path
| expression
;

havingClause
Expand All @@ -315,7 +320,7 @@ orderByClause
;

orderItem
: path (DESC | ASC)?
: (path | expression) (DESC | ASC)?
;

downsampleClause
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -583,15 +583,15 @@ private static Operator buildLimit(SelectStatement selectStatement, Operator roo
* @return 添加了Sort操作符的根节点;如果没有Order By子句,返回原根节点
*/
private static Operator buildOrderByPaths(SelectStatement selectStatement, Operator root) {
if (selectStatement.getOrderByPaths().isEmpty()) {
if (selectStatement.getOrderByExpressions().isEmpty()) {
return root;
}
List<Sort.SortType> sortTypes = new ArrayList<>();
selectStatement
.getAscendingList()
.forEach(
isAscending -> sortTypes.add(isAscending ? Sort.SortType.ASC : Sort.SortType.DESC));
return new Sort(new OperatorSource(root), selectStatement.getOrderByPaths(), sortTypes);
return new Sort(new OperatorSource(root), selectStatement.getOrderByExpressions(), sortTypes);
}

/**
Expand Down Expand Up @@ -662,7 +662,7 @@ private Operator buildGroupByQuery(UnarySelectStatement selectStatement, Operato
List<FunctionCall> functionCallList =
getFunctionCallList(selectStatement, MappingType.SetMapping);
return new GroupBy(
new OperatorSource(root), selectStatement.getGroupByPaths(), functionCallList);
new OperatorSource(root), selectStatement.getGroupByExpressions(), functionCallList);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import static cn.edu.tsinghua.iginx.engine.shared.operator.type.OperatorType.isUnaryOperator;

import cn.edu.tsinghua.iginx.engine.shared.expr.BaseExpression;
import cn.edu.tsinghua.iginx.engine.shared.expr.Expression;
import cn.edu.tsinghua.iginx.engine.shared.function.FunctionCall;
import cn.edu.tsinghua.iginx.engine.shared.function.FunctionParams;
import cn.edu.tsinghua.iginx.engine.shared.function.FunctionUtils;
Expand Down Expand Up @@ -316,18 +317,19 @@ private static Operator pushDownApply(Operator root, List<String> correlatedVari
root =
new GroupBy(
new OperatorSource(pushDownApply(apply, correlatedVariables)),
correlatedVariables,
correlatedVariables.stream().map(BaseExpression::new).collect(Collectors.toList()),
setTransform.getFunctionCallList());
break;
case GroupBy:
GroupBy groupBy = (GroupBy) operatorB;
apply.setSourceB(groupBy.getSource());
List<String> groupByCols = groupBy.getGroupByCols();
groupByCols.addAll(correlatedVariables);
List<Expression> groupByExpressions = groupBy.getGroupByExpressions();
groupByExpressions.addAll(
correlatedVariables.stream().map(BaseExpression::new).collect(Collectors.toList()));
root =
new GroupBy(
new OperatorSource(pushDownApply(apply, correlatedVariables)),
groupByCols,
groupByExpressions,
groupBy.getFunctionCallList());
break;
case Rename:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,11 @@ private RowStream executeSelect(Select select, Table table) throws PhysicalExcep
}

private RowStream executeSort(Sort sort, Table table) throws PhysicalException {
RowTransform preRowTransform = HeaderUtils.checkSortHeader(table.getHeader(), sort);
if (preRowTransform != null) {
table = transformToTable(executeRowTransform(preRowTransform, table));
}

List<Boolean> ascendingList = sort.getAscendingList();
RowUtils.sortRows(table.getRows(), ascendingList, sort.getSortByCols());
return table;
Expand Down Expand Up @@ -483,6 +488,11 @@ private RowStream executeAddSchemaPrefix(AddSchemaPrefix addSchemaPrefix, Table
}

private RowStream executeGroupBy(GroupBy groupBy, Table table) throws PhysicalException {
RowTransform preRowTransform = HeaderUtils.checkGroupByHeader(table.getHeader(), groupBy);
if (preRowTransform != null) {
table = transformToTable(executeRowTransform(preRowTransform, table));
}

List<Row> rows = RowUtils.cacheGroupByResult(groupBy, table);
if (rows.isEmpty()) {
return Table.EMPTY_TABLE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import cn.edu.tsinghua.iginx.engine.physical.exception.UnexpectedOperatorException;
import cn.edu.tsinghua.iginx.engine.physical.memory.execute.OperatorMemoryExecutor;
import cn.edu.tsinghua.iginx.engine.physical.memory.execute.Table;
import cn.edu.tsinghua.iginx.engine.physical.memory.execute.utils.HeaderUtils;
import cn.edu.tsinghua.iginx.engine.physical.memory.execute.utils.RowUtils;
import cn.edu.tsinghua.iginx.engine.shared.Constants;
import cn.edu.tsinghua.iginx.engine.shared.RequestContext;
Expand Down Expand Up @@ -198,6 +199,11 @@ private RowStream executeSelect(Select select, RowStream stream) {
}

private RowStream executeSort(Sort sort, RowStream stream) throws PhysicalException {
RowTransform preRowTransform = HeaderUtils.checkSortHeader(stream.getHeader(), sort);
if (preRowTransform != null) {
stream = executeRowTransform(preRowTransform, stream);
}

return new SortLazyStream(sort, stream);
}

Expand Down Expand Up @@ -270,7 +276,12 @@ private RowStream executeAddSchemaPrefix(AddSchemaPrefix addSchemaPrefix, RowStr
return new AddSchemaPrefixLazyStream(addSchemaPrefix, stream);
}

private RowStream executeGroupBy(GroupBy groupBy, RowStream stream) {
private RowStream executeGroupBy(GroupBy groupBy, RowStream stream) throws PhysicalException {
RowTransform preRowTransform = HeaderUtils.checkGroupByHeader(stream.getHeader(), groupBy);
if (preRowTransform != null) {
stream = executeRowTransform(preRowTransform, stream);
}

return new GroupByLazyStream(groupBy, stream);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
*/
package cn.edu.tsinghua.iginx.engine.physical.memory.execute.utils;

import static cn.edu.tsinghua.iginx.engine.shared.function.system.ArithmeticExpr.ARITHMETIC_EXPR;
import static cn.edu.tsinghua.iginx.engine.shared.function.system.utils.ValueUtils.isNumericType;
import static cn.edu.tsinghua.iginx.sql.SQLConstant.DOT;
import static cn.edu.tsinghua.iginx.thrift.DataType.BOOLEAN;
Expand All @@ -28,13 +29,26 @@
import cn.edu.tsinghua.iginx.engine.physical.exception.PhysicalException;
import cn.edu.tsinghua.iginx.engine.shared.data.read.Field;
import cn.edu.tsinghua.iginx.engine.shared.data.read.Header;
import cn.edu.tsinghua.iginx.engine.shared.expr.BaseExpression;
import cn.edu.tsinghua.iginx.engine.shared.expr.Expression;
import cn.edu.tsinghua.iginx.engine.shared.expr.KeyExpression;
import cn.edu.tsinghua.iginx.engine.shared.function.Function;
import cn.edu.tsinghua.iginx.engine.shared.function.FunctionCall;
import cn.edu.tsinghua.iginx.engine.shared.function.FunctionParams;
import cn.edu.tsinghua.iginx.engine.shared.function.manager.FunctionManager;
import cn.edu.tsinghua.iginx.engine.shared.operator.GroupBy;
import cn.edu.tsinghua.iginx.engine.shared.operator.RowTransform;
import cn.edu.tsinghua.iginx.engine.shared.operator.Sort;
import cn.edu.tsinghua.iginx.engine.shared.operator.filter.*;
import cn.edu.tsinghua.iginx.engine.shared.source.EmptySource;
import cn.edu.tsinghua.iginx.thrift.DataType;
import cn.edu.tsinghua.iginx.utils.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;

public class HeaderUtils {

Expand Down Expand Up @@ -330,4 +344,55 @@ public static void checkHeadersComparable(Header headerA, Header headerB)
}
}
}

public static RowTransform checkGroupByHeader(Header header, GroupBy groupBy) {
Set<Expression> appendExpressions = new HashSet<>();
for (Expression groupByExpr : groupBy.getGroupByExpressions()) {
String exprName = groupByExpr.getColumnName();
boolean found =
header.getFields().stream().anyMatch(field -> field.getName().equals(exprName));
if (!found) {
appendExpressions.add(groupByExpr);
}
}

if (appendExpressions.isEmpty()) {
return null;
}
return appendArithExpressions(header, new ArrayList<>(appendExpressions));
}

public static RowTransform checkSortHeader(Header header, Sort sort) {
List<Expression> sortExpressions = new ArrayList<>(sort.getSortByExpressions());
if (sortExpressions.get(0) instanceof KeyExpression) {
sortExpressions.remove(0);
}
Set<Expression> appendExpressions = new HashSet<>();
for (Expression sortExpr : sortExpressions) {
String exprName = sortExpr.getColumnName();
boolean found =
header.getFields().stream().anyMatch(field -> field.getName().equals(exprName));
if (!found) {
appendExpressions.add(sortExpr);
}
}

if (appendExpressions.isEmpty()) {
return null;
}
return appendArithExpressions(header, new ArrayList<>(appendExpressions));
}

private static RowTransform appendArithExpressions(Header header, List<Expression> expressions) {
List<FunctionCall> functionCallList = new ArrayList<>();
Function function = FunctionManager.getInstance().getFunction(ARITHMETIC_EXPR);
for (Field field : header.getFields()) {
functionCallList.add(
new FunctionCall(function, new FunctionParams(new BaseExpression(field.getName()))));
}
for (Expression expr : expressions) {
functionCallList.add(new FunctionCall(function, new FunctionParams(expr)));
}
return new RowTransform(EmptySource.EMPTY_SOURCE, functionCallList);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,16 @@ public void setAlias(String alias) {
public void accept(ExpressionVisitor visitor) {
visitor.visit(this);
}

@Override
public boolean equalExceptAlias(Expression expr) {
if (this == expr) {
return true;
}
if (expr == null || expr.getType() != ExpressionType.Base) {
return false;
}
BaseExpression that = (BaseExpression) expr;
return this.pathName.equals(that.pathName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,18 @@ public void accept(ExpressionVisitor visitor) {
leftExpression.accept(visitor);
rightExpression.accept(visitor);
}

@Override
public boolean equalExceptAlias(Expression expr) {
if (this == expr) {
return true;
}
if (expr == null || expr.getType() != ExpressionType.Binary) {
return false;
}
BinaryExpression that = (BinaryExpression) expr;
return this.leftExpression.equalExceptAlias(that.leftExpression)
&& this.rightExpression.equalExceptAlias(that.rightExpression)
&& this.op == that.op;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,16 @@ public void accept(ExpressionVisitor visitor) {
visitor.visit(this);
expression.accept(visitor);
}

@Override
public boolean equalExceptAlias(Expression expr) {
if (this == expr) {
return true;
}
if (expr == null || expr.getType() != ExpressionType.Bracket) {
return false;
}
BracketExpression that = (BracketExpression) expr;
return this.expression.equalExceptAlias(that.expression);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,38 @@ public void accept(ExpressionVisitor visitor) {
}
resultElse.accept(visitor);
}

@Override
public boolean equalExceptAlias(Expression expr) {
if (this == expr) {
return true;
}
if (expr == null || expr.getType() != ExpressionType.CaseWhen) {
return false;
}
CaseWhenExpression that = (CaseWhenExpression) expr;
if (this.conditions.size() != that.conditions.size()) {
return false;
}
for (int i = 0; i < this.conditions.size(); i++) {
if (!this.conditions.get(i).equals(that.conditions.get(i))) {
return false;
}
}
if (this.results.size() != that.results.size()) {
return false;
}
for (int i = 0; i < this.results.size(); i++) {
if (!this.results.get(i).equalExceptAlias(that.results.get(i))) {
return false;
}
}
if (this.resultElse == null && that.resultElse == null) {
return true;
}
if (this.resultElse == null || that.resultElse == null) {
return false;
}
return this.resultElse.equalExceptAlias(that.resultElse);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
*/
package cn.edu.tsinghua.iginx.engine.shared.expr;

import java.util.Arrays;

public class ConstantExpression implements Expression {

private final Object value;
Expand Down Expand Up @@ -70,4 +72,20 @@ public void setAlias(String alias) {
public void accept(ExpressionVisitor visitor) {
visitor.visit(this);
}

@Override
public boolean equalExceptAlias(Expression expr) {
if (this == expr) {
return true;
}
if (expr == null || expr.getType() != ExpressionType.Constant) {
return false;
}
ConstantExpression that = (ConstantExpression) expr;
if (value instanceof byte[]) {
return Arrays.equals((byte[]) value, (byte[]) that.value);
} else {
return this.value.equals(that.value);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ public interface Expression {

void accept(ExpressionVisitor visitor);

boolean equalExceptAlias(Expression expr);

enum ExpressionType {
Bracket,
Binary,
Expand Down
Loading
Loading