From 89785a9f2a99ccb53b1ca3313432d6889f20ec53 Mon Sep 17 00:00:00 2001 From: jzl18thu Date: Wed, 16 Oct 2024 03:50:53 +0800 Subject: [PATCH 1/6] feat(sql): GROUP BY expr && ORDER BY expr MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1.支持对GROUP BY和ORDER BY中的列使用RowToRow表达式 2.支持GROUP BY和ORDER BY中的列与SELECT子句中的别名进行匹配 --- .../antlr4/cn/edu/tsinghua/iginx/sql/Sql.g4 | 8 +- .../logical/generator/QueryGenerator.java | 6 +- .../engine/logical/utils/OperatorUtils.java | 10 +- .../naive/NaiveOperatorMemoryExecutor.java | 10 + .../stream/StreamOperatorMemoryExecutor.java | 13 +- .../memory/execute/utils/HeaderUtils.java | 67 ++++++ .../engine/shared/expr/BaseExpression.java | 12 ++ .../engine/shared/expr/BinaryExpression.java | 14 ++ .../engine/shared/expr/BracketExpression.java | 12 ++ .../shared/expr/CaseWhenExpression.java | 34 +++ .../shared/expr/ConstantExpression.java | 18 ++ .../iginx/engine/shared/expr/Expression.java | 2 + .../shared/expr/FromValueExpression.java | 16 +- .../engine/shared/expr/FuncExpression.java | 37 ++++ .../engine/shared/expr/KeyExpression.java | 8 + .../shared/expr/MultipleExpression.java | 22 ++ .../shared/expr/SequenceExpression.java | 12 ++ .../engine/shared/expr/UnaryExpression.java | 12 ++ .../shared/function/FunctionParams.java | 5 + .../iginx/engine/shared/operator/GroupBy.java | 40 +++- .../iginx/engine/shared/operator/Sort.java | 38 +++- .../shared/operator/filter/ExprFilter.java | 20 ++ .../tsinghua/iginx/sql/IginXSqlVisitor.java | 203 ++++++++++++++---- .../select/BinarySelectStatement.java | 5 + .../sql/statement/select/SelectStatement.java | 14 +- .../select/UnarySelectStatement.java | 127 ++--------- .../select/subclause/GroupByClause.java | 13 +- .../select/subclause/OrderByClause.java | 30 ++- .../iginx/sql/utils/ExpressionUtils.java | 113 ++++++++++ .../AbstractOperatorMemoryExecutorTest.java | 7 +- .../cn/edu/tsinghua/iginx/sql/ParseTest.java | 13 +- .../optimizer/rules/ColumnPruningRule.java | 25 ++- .../integration/func/sql/SQLSessionIT.java | 58 +++++ .../iginx/integration/func/udf/UDFIT.java | 58 +++++ 34 files changed, 884 insertions(+), 198 deletions(-) diff --git a/antlr/src/main/antlr4/cn/edu/tsinghua/iginx/sql/Sql.g4 b/antlr/src/main/antlr4/cn/edu/tsinghua/iginx/sql/Sql.g4 index 6b664d9655..74ad2f79e0 100644 --- a/antlr/src/main/antlr4/cn/edu/tsinghua/iginx/sql/Sql.g4 +++ b/antlr/src/main/antlr4/cn/edu/tsinghua/iginx/sql/Sql.g4 @@ -148,11 +148,11 @@ functionName ; caseSpecification - : simipleCase + : simpleCase | searchedCase ; -simipleCase +simpleCase : CASE expression simpleWhenClause (simpleWhenClause)* elseClause? END ; @@ -301,7 +301,7 @@ specialClause ; groupByClause - : GROUP BY path (COMMA path)* + : GROUP BY expression (COMMA expression)* ; havingClause @@ -313,7 +313,7 @@ orderByClause ; orderItem - : path (DESC | ASC)? + : expression (DESC | ASC)? ; downsampleClause diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/logical/generator/QueryGenerator.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/logical/generator/QueryGenerator.java index 7f2961bdb7..8debde8f66 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/logical/generator/QueryGenerator.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/logical/generator/QueryGenerator.java @@ -582,7 +582,7 @@ private static Operator buildLimit(SelectStatement selectStatement, Operator roo * @return 添加了Sort操作符的根节点;如果没有Order By子句,返回原根节点 */ private static Operator buildOrderByPaths(SelectStatement selectStatement, Operator root) { - if (selectStatement.getOrderByPaths().isEmpty()) { + if (selectStatement.getOrderByExpressions().isEmpty()) { return root; } List sortTypes = new ArrayList<>(); @@ -590,7 +590,7 @@ private static Operator buildOrderByPaths(SelectStatement selectStatement, Opera .getAscendingList() .forEach( isAscending -> sortTypes.add(isAscending ? Sort.SortType.ASC : Sort.SortType.DESC)); - return new Sort(new OperatorSource(root), selectStatement.getOrderByPaths(), sortTypes); + return new Sort(new OperatorSource(root), selectStatement.getOrderByExpressions(), sortTypes); } /** @@ -661,7 +661,7 @@ private Operator buildGroupByQuery(UnarySelectStatement selectStatement, Operato List functionCallList = getFunctionCallList(selectStatement, MappingType.SetMapping); return new GroupBy( - new OperatorSource(root), selectStatement.getGroupByPaths(), functionCallList); + new OperatorSource(root), selectStatement.getGroupByExpressions(), functionCallList); } /** diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/logical/utils/OperatorUtils.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/logical/utils/OperatorUtils.java index 2edc0833cd..aab9f1a990 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/logical/utils/OperatorUtils.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/logical/utils/OperatorUtils.java @@ -26,6 +26,7 @@ import static cn.edu.tsinghua.iginx.engine.shared.operator.type.OperatorType.isUnaryOperator; import cn.edu.tsinghua.iginx.engine.shared.expr.BaseExpression; +import cn.edu.tsinghua.iginx.engine.shared.expr.Expression; import cn.edu.tsinghua.iginx.engine.shared.function.FunctionCall; import cn.edu.tsinghua.iginx.engine.shared.function.FunctionParams; import cn.edu.tsinghua.iginx.engine.shared.function.FunctionUtils; @@ -315,18 +316,19 @@ private static Operator pushDownApply(Operator root, List correlatedVari root = new GroupBy( new OperatorSource(pushDownApply(apply, correlatedVariables)), - correlatedVariables, + correlatedVariables.stream().map(BaseExpression::new).collect(Collectors.toList()), setTransform.getFunctionCallList()); break; case GroupBy: GroupBy groupBy = (GroupBy) operatorB; apply.setSourceB(groupBy.getSource()); - List groupByCols = groupBy.getGroupByCols(); - groupByCols.addAll(correlatedVariables); + List groupByExpressions = groupBy.getGroupByExpressions(); + groupByExpressions.addAll( + correlatedVariables.stream().map(BaseExpression::new).collect(Collectors.toList())); root = new GroupBy( new OperatorSource(pushDownApply(apply, correlatedVariables)), - groupByCols, + groupByExpressions, groupBy.getFunctionCallList()); break; case Rename: diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/naive/NaiveOperatorMemoryExecutor.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/naive/NaiveOperatorMemoryExecutor.java index 900ad3ef54..5e6c9d7ac6 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/naive/NaiveOperatorMemoryExecutor.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/naive/NaiveOperatorMemoryExecutor.java @@ -259,6 +259,11 @@ private RowStream executeSelect(Select select, Table table) throws PhysicalExcep } private RowStream executeSort(Sort sort, Table table) throws PhysicalException { + RowTransform preRowTransform = HeaderUtils.checkSortHeader(table.getHeader(), sort); + if (preRowTransform != null) { + table = transformToTable(executeRowTransform(preRowTransform, table)); + } + List ascendingList = sort.getAscendingList(); RowUtils.sortRows(table.getRows(), ascendingList, sort.getSortByCols()); return table; @@ -481,6 +486,11 @@ private RowStream executeAddSchemaPrefix(AddSchemaPrefix addSchemaPrefix, Table } private RowStream executeGroupBy(GroupBy groupBy, Table table) throws PhysicalException { + RowTransform preRowTransform = HeaderUtils.checkGroupByHeader(table.getHeader(), groupBy); + if (preRowTransform != null) { + table = transformToTable(executeRowTransform(preRowTransform, table)); + } + List rows = RowUtils.cacheGroupByResult(groupBy, table); if (rows.isEmpty()) { return Table.EMPTY_TABLE; diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/stream/StreamOperatorMemoryExecutor.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/stream/StreamOperatorMemoryExecutor.java index 6ef5c54896..b0ae4a24a4 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/stream/StreamOperatorMemoryExecutor.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/stream/StreamOperatorMemoryExecutor.java @@ -24,6 +24,7 @@ import cn.edu.tsinghua.iginx.engine.physical.exception.UnexpectedOperatorException; import cn.edu.tsinghua.iginx.engine.physical.memory.execute.OperatorMemoryExecutor; import cn.edu.tsinghua.iginx.engine.physical.memory.execute.Table; +import cn.edu.tsinghua.iginx.engine.physical.memory.execute.utils.HeaderUtils; import cn.edu.tsinghua.iginx.engine.physical.memory.execute.utils.RowUtils; import cn.edu.tsinghua.iginx.engine.shared.Constants; import cn.edu.tsinghua.iginx.engine.shared.RequestContext; @@ -196,6 +197,11 @@ private RowStream executeSelect(Select select, RowStream stream) { } private RowStream executeSort(Sort sort, RowStream stream) throws PhysicalException { + RowTransform preRowTransform = HeaderUtils.checkSortHeader(stream.getHeader(), sort); + if (preRowTransform != null) { + stream = executeRowTransform(preRowTransform, stream); + } + return new SortLazyStream(sort, stream); } @@ -268,7 +274,12 @@ private RowStream executeAddSchemaPrefix(AddSchemaPrefix addSchemaPrefix, RowStr return new AddSchemaPrefixLazyStream(addSchemaPrefix, stream); } - private RowStream executeGroupBy(GroupBy groupBy, RowStream stream) { + private RowStream executeGroupBy(GroupBy groupBy, RowStream stream) throws PhysicalException { + RowTransform preRowTransform = HeaderUtils.checkGroupByHeader(stream.getHeader(), groupBy); + if (preRowTransform != null) { + stream = executeRowTransform(preRowTransform, stream); + } + return new GroupByLazyStream(groupBy, stream); } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/utils/HeaderUtils.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/utils/HeaderUtils.java index b569ac7758..dfd68ba422 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/utils/HeaderUtils.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/utils/HeaderUtils.java @@ -18,6 +18,7 @@ package cn.edu.tsinghua.iginx.engine.physical.memory.execute.utils; +import static cn.edu.tsinghua.iginx.engine.shared.function.system.ArithmeticExpr.ARITHMETIC_EXPR; import static cn.edu.tsinghua.iginx.engine.shared.function.system.utils.ValueUtils.isNumericType; import static cn.edu.tsinghua.iginx.sql.SQLConstant.DOT; import static cn.edu.tsinghua.iginx.thrift.DataType.BOOLEAN; @@ -27,16 +28,31 @@ import cn.edu.tsinghua.iginx.engine.physical.exception.PhysicalException; import cn.edu.tsinghua.iginx.engine.shared.data.read.Field; import cn.edu.tsinghua.iginx.engine.shared.data.read.Header; +import cn.edu.tsinghua.iginx.engine.shared.expr.BaseExpression; +import cn.edu.tsinghua.iginx.engine.shared.expr.Expression; +import cn.edu.tsinghua.iginx.engine.shared.expr.KeyExpression; +import cn.edu.tsinghua.iginx.engine.shared.function.Function; +import cn.edu.tsinghua.iginx.engine.shared.function.FunctionCall; +import cn.edu.tsinghua.iginx.engine.shared.function.FunctionParams; +import cn.edu.tsinghua.iginx.engine.shared.function.manager.FunctionManager; +import cn.edu.tsinghua.iginx.engine.shared.operator.GroupBy; +import cn.edu.tsinghua.iginx.engine.shared.operator.RowTransform; +import cn.edu.tsinghua.iginx.engine.shared.operator.Sort; import cn.edu.tsinghua.iginx.engine.shared.operator.filter.*; +import cn.edu.tsinghua.iginx.engine.shared.source.EmptySource; import cn.edu.tsinghua.iginx.thrift.DataType; import cn.edu.tsinghua.iginx.utils.Pair; import java.util.ArrayList; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Objects; +import java.util.Set; public class HeaderUtils { + private static final FunctionManager functionManager = FunctionManager.getInstance(); + public static Header constructNewHead(Header header, String markColumn) { List fields = new ArrayList<>(header.getFields()); fields.add(new Field(markColumn, BOOLEAN)); @@ -329,4 +345,55 @@ public static void checkHeadersComparable(Header headerA, Header headerB) } } } + + public static RowTransform checkGroupByHeader(Header header, GroupBy groupBy) { + Set appendExpressions = new HashSet<>(); + for (Expression groupByExpr : groupBy.getGroupByExpressions()) { + String exprName = groupByExpr.getColumnName(); + boolean found = + header.getFields().stream().anyMatch(field -> field.getName().equals(exprName)); + if (!found) { + appendExpressions.add(groupByExpr); + } + } + + if (appendExpressions.isEmpty()) { + return null; + } + return appendArithExpressions(header, new ArrayList<>(appendExpressions)); + } + + public static RowTransform checkSortHeader(Header header, Sort sort) { + List sortExpressions = new ArrayList<>(sort.getSortByExpressions()); + if (sortExpressions.get(0) instanceof KeyExpression) { + sortExpressions.remove(0); + } + Set appendExpressions = new HashSet<>(); + for (Expression sortExpr : sortExpressions) { + String exprName = sortExpr.getColumnName(); + boolean found = + header.getFields().stream().anyMatch(field -> field.getName().equals(exprName)); + if (!found) { + appendExpressions.add(sortExpr); + } + } + + if (appendExpressions.isEmpty()) { + return null; + } + return appendArithExpressions(header, new ArrayList<>(appendExpressions)); + } + + private static RowTransform appendArithExpressions(Header header, List expressions) { + List functionCallList = new ArrayList<>(); + Function function = functionManager.getFunction(ARITHMETIC_EXPR); + for (Field field : header.getFields()) { + functionCallList.add( + new FunctionCall(function, new FunctionParams(new BaseExpression(field.getName())))); + } + for (Expression expr : expressions) { + functionCallList.add(new FunctionCall(function, new FunctionParams(expr))); + } + return new RowTransform(EmptySource.EMPTY_SOURCE, functionCallList); + } } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/BaseExpression.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/BaseExpression.java index 852af9629d..b1cccc35b7 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/BaseExpression.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/BaseExpression.java @@ -69,4 +69,16 @@ public void setAlias(String alias) { public void accept(ExpressionVisitor visitor) { visitor.visit(this); } + + @Override + public boolean equalExceptAlias(Expression expr) { + if (this == expr) { + return true; + } + if (expr == null || expr.getType() != ExpressionType.Base) { + return false; + } + BaseExpression that = (BaseExpression) expr; + return this.pathName.equals(that.pathName); + } } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/BinaryExpression.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/BinaryExpression.java index 6a6d8009b0..f7d0e11756 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/BinaryExpression.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/BinaryExpression.java @@ -96,4 +96,18 @@ public void accept(ExpressionVisitor visitor) { leftExpression.accept(visitor); rightExpression.accept(visitor); } + + @Override + public boolean equalExceptAlias(Expression expr) { + if (this == expr) { + return true; + } + if (expr == null || expr.getType() != ExpressionType.Binary) { + return false; + } + BinaryExpression that = (BinaryExpression) expr; + return this.leftExpression.equalExceptAlias(that.leftExpression) + && this.rightExpression.equalExceptAlias(that.rightExpression) + && this.op == that.op; + } } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/BracketExpression.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/BracketExpression.java index 9695abea76..9dd11763e2 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/BracketExpression.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/BracketExpression.java @@ -70,4 +70,16 @@ public void accept(ExpressionVisitor visitor) { visitor.visit(this); expression.accept(visitor); } + + @Override + public boolean equalExceptAlias(Expression expr) { + if (this == expr) { + return true; + } + if (expr == null || expr.getType() != ExpressionType.Bracket) { + return false; + } + BracketExpression that = (BracketExpression) expr; + return this.expression.equalExceptAlias(that.expression); + } } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/CaseWhenExpression.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/CaseWhenExpression.java index 9a2eb0ba41..e0e078adbc 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/CaseWhenExpression.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/CaseWhenExpression.java @@ -105,4 +105,38 @@ public void accept(ExpressionVisitor visitor) { } resultElse.accept(visitor); } + + @Override + public boolean equalExceptAlias(Expression expr) { + if (this == expr) { + return true; + } + if (expr == null || expr.getType() != ExpressionType.CaseWhen) { + return false; + } + CaseWhenExpression that = (CaseWhenExpression) expr; + if (this.conditions.size() != that.conditions.size()) { + return false; + } + for (int i = 0; i < this.conditions.size(); i++) { + if (!this.conditions.get(i).equals(that.conditions.get(i))) { + return false; + } + } + if (this.results.size() != that.results.size()) { + return false; + } + for (int i = 0; i < this.results.size(); i++) { + if (!this.results.get(i).equalExceptAlias(that.results.get(i))) { + return false; + } + } + if (this.resultElse == null && that.resultElse == null) { + return true; + } + if (this.resultElse == null || that.resultElse == null) { + return false; + } + return this.resultElse.equalExceptAlias(that.resultElse); + } } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/ConstantExpression.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/ConstantExpression.java index 86015e072a..d5ff9e3b40 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/ConstantExpression.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/ConstantExpression.java @@ -18,6 +18,8 @@ package cn.edu.tsinghua.iginx.engine.shared.expr; +import java.util.Arrays; + public class ConstantExpression implements Expression { private final Object value; @@ -69,4 +71,20 @@ public void setAlias(String alias) { public void accept(ExpressionVisitor visitor) { visitor.visit(this); } + + @Override + public boolean equalExceptAlias(Expression expr) { + if (this == expr) { + return true; + } + if (expr == null || expr.getType() != ExpressionType.Constant) { + return false; + } + ConstantExpression that = (ConstantExpression) expr; + if (value instanceof byte[]) { + return Arrays.equals((byte[]) value, (byte[]) that.value); + } else { + return this.value.equals(that.value); + } + } } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/Expression.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/Expression.java index 1569edf8ab..65cb943dd1 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/Expression.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/Expression.java @@ -32,6 +32,8 @@ public interface Expression { void accept(ExpressionVisitor visitor); + boolean equalExceptAlias(Expression expr); + enum ExpressionType { Bracket, Binary, diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/FromValueExpression.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/FromValueExpression.java index 9134b21561..7482d2a349 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/FromValueExpression.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/FromValueExpression.java @@ -34,7 +34,7 @@ public SelectStatement getSubStatement() { @Override public String getColumnName() { - return ""; + return "*"; } @Override @@ -49,7 +49,7 @@ public boolean hasAlias() { @Override public String getAlias() { - return null; + return ""; } @Override @@ -59,4 +59,16 @@ public void setAlias(String alias) {} public void accept(ExpressionVisitor visitor) { visitor.visit(this); } + + @Override + public boolean equalExceptAlias(Expression expr) { + if (this == expr) { + return true; + } + if (expr == null || expr.getType() != ExpressionType.FromValue) { + return false; + } + FromValueExpression that = (FromValueExpression) expr; + return this.subStatement.equals(that.subStatement); + } } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/FuncExpression.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/FuncExpression.java index 88e364bcf1..a1d9c91e38 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/FuncExpression.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/FuncExpression.java @@ -127,4 +127,41 @@ public void accept(ExpressionVisitor visitor) { visitor.visit(this); expressions.forEach(e -> e.accept(visitor)); } + + @Override + public boolean equalExceptAlias(Expression expr) { + if (this == expr) { + return true; + } + if (expr == null || expr.getType() != ExpressionType.Function) { + return false; + } + FuncExpression that = (FuncExpression) expr; + + if (this.isPyUDF != that.isPyUDF) { + return false; + } + if (this.isPyUDF) { + if (!this.funcName.equals(that.funcName)) { + return false; + } + } else { + if (!this.funcName.equalsIgnoreCase(that.funcName)) { + return false; + } + } + + if (this.getExpressions().size() != that.getExpressions().size()) { + return false; + } + for (int i = 0; i < this.getExpressions().size(); i++) { + if (!this.getExpressions().get(i).equalExceptAlias(that.getExpressions().get(i))) { + return false; + } + } + + return this.args.equals(that.args) + && this.kvargs.equals(that.kvargs) + && this.isDistinct == that.isDistinct; + } } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/KeyExpression.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/KeyExpression.java index 474fd4f801..42c2dbfacc 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/KeyExpression.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/KeyExpression.java @@ -58,4 +58,12 @@ public void setAlias(String alias) { public void accept(ExpressionVisitor visitor) { visitor.visit(this); } + + @Override + public boolean equalExceptAlias(Expression expr) { + if (this == expr) { + return true; + } + return expr != null && expr.getType() == ExpressionType.Key; + } } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/MultipleExpression.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/MultipleExpression.java index 8f73f354f4..3cc64c52f0 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/MultipleExpression.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/MultipleExpression.java @@ -110,4 +110,26 @@ public void accept(ExpressionVisitor visitor) { visitor.visit(this); children.forEach(e -> e.accept(visitor)); } + + @Override + public boolean equalExceptAlias(Expression expr) { + if (this == expr) { + return true; + } + if (expr == null || expr.getType() != ExpressionType.Multiple) { + return false; + } + MultipleExpression that = (MultipleExpression) expr; + + if (this.getChildren().size() != that.getChildren().size()) { + return false; + } + for (int i = 0; i < this.getChildren().size(); i++) { + if (!this.getChildren().get(i).equalExceptAlias(that.getChildren().get(i))) { + return false; + } + } + + return this.ops.equals(that.ops); + } } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/SequenceExpression.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/SequenceExpression.java index 50fa39e62d..1e044eab79 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/SequenceExpression.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/SequenceExpression.java @@ -76,4 +76,16 @@ public void setAlias(String alias) { public void accept(ExpressionVisitor visitor) { visitor.visit(this); } + + @Override + public boolean equalExceptAlias(Expression expr) { + if (this == expr) { + return true; + } + if (expr == null || expr.getType() != ExpressionType.Sequence) { + return false; + } + SequenceExpression that = (SequenceExpression) expr; + return this.start == that.start && this.increment == that.increment; + } } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/UnaryExpression.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/UnaryExpression.java index de051e6478..5bab7b55be 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/UnaryExpression.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/expr/UnaryExpression.java @@ -76,4 +76,16 @@ public void accept(ExpressionVisitor visitor) { visitor.visit(this); expression.accept(visitor); } + + @Override + public boolean equalExceptAlias(Expression expr) { + if (this == expr) { + return true; + } + if (expr == null || expr.getType() != ExpressionType.Unary) { + return false; + } + UnaryExpression that = (UnaryExpression) expr; + return this.expression.equalExceptAlias(that.expression) && this.operator.equals(that.operator); + } } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/function/FunctionParams.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/function/FunctionParams.java index 510db84e39..6a04e755f1 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/function/FunctionParams.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/function/FunctionParams.java @@ -21,6 +21,7 @@ import cn.edu.tsinghua.iginx.engine.physical.memory.execute.utils.ExprUtils; import cn.edu.tsinghua.iginx.engine.shared.expr.Expression; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -38,6 +39,10 @@ public class FunctionParams { private boolean isDistinct; + public FunctionParams(Expression expression) { + this(Collections.singletonList(expression), null, null, false); + } + public FunctionParams(List expressions) { this(expressions, null, null, false); } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/operator/GroupBy.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/operator/GroupBy.java index 1af269b5f9..620dc60352 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/operator/GroupBy.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/operator/GroupBy.java @@ -18,27 +18,39 @@ package cn.edu.tsinghua.iginx.engine.shared.operator; +import cn.edu.tsinghua.iginx.engine.physical.memory.execute.utils.ExprUtils; +import cn.edu.tsinghua.iginx.engine.shared.expr.Expression; import cn.edu.tsinghua.iginx.engine.shared.function.FunctionCall; import cn.edu.tsinghua.iginx.engine.shared.operator.type.OperatorType; import cn.edu.tsinghua.iginx.engine.shared.source.Source; import java.util.ArrayList; import java.util.List; +import java.util.stream.Collectors; public class GroupBy extends AbstractUnaryOperator { + private final List groupByExpressions; + private final List groupByCols; private final List functionCallList; - public GroupBy(Source source, List groupByCols, List functionCallList) { + public GroupBy( + Source source, List groupByExpressions, List functionCallList) { super(OperatorType.GroupBy, source); - if (groupByCols == null || groupByCols.isEmpty()) { + if (groupByExpressions == null || groupByExpressions.isEmpty()) { throw new IllegalArgumentException("groupByCols shouldn't be null"); } - this.groupByCols = groupByCols; + this.groupByExpressions = groupByExpressions; + this.groupByCols = + groupByExpressions.stream().map(Expression::getColumnName).collect(Collectors.toList()); this.functionCallList = functionCallList; } + public List getGroupByExpressions() { + return groupByExpressions; + } + public List getGroupByCols() { return groupByCols; } @@ -49,13 +61,21 @@ public List getFunctionCallList() { @Override public Operator copy() { + List copyGroupByExpressions = new ArrayList<>(groupByExpressions.size()); + for (Expression expression : groupByExpressions) { + copyGroupByExpressions.add(ExprUtils.copy(expression)); + } return new GroupBy( - getSource().copy(), new ArrayList<>(groupByCols), new ArrayList<>(functionCallList)); + getSource().copy(), copyGroupByExpressions, new ArrayList<>(functionCallList)); } @Override public UnaryOperator copyWithSource(Source source) { - return new GroupBy(source, new ArrayList<>(groupByCols), new ArrayList<>(functionCallList)); + List copyGroupByExpressions = new ArrayList<>(groupByExpressions.size()); + for (Expression expression : groupByExpressions) { + copyGroupByExpressions.add(ExprUtils.copy(expression)); + } + return new GroupBy(source, copyGroupByExpressions, new ArrayList<>(functionCallList)); } public boolean isDistinct() { @@ -95,6 +115,14 @@ public boolean equals(Object object) { return false; } GroupBy that = (GroupBy) object; - return groupByCols.equals(that.groupByCols) && functionCallList.equals(that.functionCallList); + if (this.groupByExpressions.size() != that.groupByExpressions.size()) { + return false; + } + for (int i = 0; i < this.groupByExpressions.size(); i++) { + if (!this.groupByExpressions.get(i).equalExceptAlias(that.groupByExpressions.get(i))) { + return false; + } + } + return functionCallList.equals(that.functionCallList); } } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/operator/Sort.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/operator/Sort.java index 84cd04ed76..9ed6ede4ce 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/operator/Sort.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/operator/Sort.java @@ -17,6 +17,8 @@ */ package cn.edu.tsinghua.iginx.engine.shared.operator; +import cn.edu.tsinghua.iginx.engine.physical.memory.execute.utils.ExprUtils; +import cn.edu.tsinghua.iginx.engine.shared.expr.Expression; import cn.edu.tsinghua.iginx.engine.shared.operator.type.OperatorType; import cn.edu.tsinghua.iginx.engine.shared.source.Source; import java.util.ArrayList; @@ -25,22 +27,30 @@ public class Sort extends AbstractUnaryOperator { + private final List sortByExpressions; + private final List sortByCols; private final List sortTypes; - public Sort(Source source, List sortByCols, List sortTypes) { + public Sort(Source source, List sortByExpressions, List sortTypes) { super(OperatorType.Sort, source); - if (sortByCols == null || sortByCols.isEmpty()) { + if (sortByExpressions == null || sortByExpressions.isEmpty()) { throw new IllegalArgumentException("sortBy shouldn't be null"); } if (sortTypes == null || sortTypes.isEmpty()) { throw new IllegalArgumentException("sortType shouldn't be null"); } - this.sortByCols = sortByCols; + this.sortByExpressions = sortByExpressions; + this.sortByCols = + sortByExpressions.stream().map(Expression::getColumnName).collect(Collectors.toList()); this.sortTypes = sortTypes; } + public List getSortByExpressions() { + return sortByExpressions; + } + public List getSortByCols() { return sortByCols; } @@ -59,12 +69,20 @@ public List getAscendingList() { @Override public Operator copy() { - return new Sort(getSource().copy(), new ArrayList<>(sortByCols), new ArrayList<>(sortTypes)); + List copySortByExpressions = new ArrayList<>(sortByExpressions.size()); + for (Expression expression : sortByExpressions) { + copySortByExpressions.add(ExprUtils.copy(expression)); + } + return new Sort(getSource().copy(), copySortByExpressions, new ArrayList<>(sortTypes)); } @Override public UnaryOperator copyWithSource(Source source) { - return new Sort(source, new ArrayList<>(sortByCols), new ArrayList<>(sortTypes)); + List copySortByExpressions = new ArrayList<>(sortByExpressions.size()); + for (Expression expression : sortByExpressions) { + copySortByExpressions.add(ExprUtils.copy(expression)); + } + return new Sort(source, copySortByExpressions, new ArrayList<>(sortTypes)); } public enum SortType { @@ -89,6 +107,14 @@ public boolean equals(Object object) { return false; } Sort sort = (Sort) object; - return sortByCols.equals(sort.sortByCols) && sortTypes.equals(sort.sortTypes); + if (this.sortByExpressions.size() != sort.sortByExpressions.size()) { + return false; + } + for (int i = 0; i < this.sortByExpressions.size(); i++) { + if (!this.sortByExpressions.get(i).equalExceptAlias(sort.sortByExpressions.get(i))) { + return false; + } + } + return sortTypes.equals(sort.sortTypes); } } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/operator/filter/ExprFilter.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/operator/filter/ExprFilter.java index f479d87704..262d7ee5bf 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/operator/filter/ExprFilter.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/operator/filter/ExprFilter.java @@ -20,6 +20,7 @@ import cn.edu.tsinghua.iginx.engine.physical.memory.execute.utils.ExprUtils; import cn.edu.tsinghua.iginx.engine.shared.expr.Expression; +import java.util.Objects; public class ExprFilter implements Filter { @@ -80,4 +81,23 @@ public Filter copy() { public String toString() { return expressionA.getColumnName() + " " + Op.op2Str(op) + " " + expressionB.getColumnName(); } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ExprFilter that = (ExprFilter) o; + return expressionA.equalExceptAlias(that.expressionA) + && op == that.op + && expressionB.equalExceptAlias(that.expressionB); + } + + @Override + public int hashCode() { + return Objects.hash(type, expressionA, expressionB, op); + } } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/sql/IginXSqlVisitor.java b/core/src/main/java/cn/edu/tsinghua/iginx/sql/IginXSqlVisitor.java index 0093d529b9..dcbaf7e5a3 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/sql/IginXSqlVisitor.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/sql/IginXSqlVisitor.java @@ -52,6 +52,7 @@ import cn.edu.tsinghua.iginx.engine.shared.file.write.ExportCsv; import cn.edu.tsinghua.iginx.engine.shared.file.write.ExportFile; import cn.edu.tsinghua.iginx.engine.shared.function.FunctionUtils; +import cn.edu.tsinghua.iginx.engine.shared.function.MappingType; import cn.edu.tsinghua.iginx.engine.shared.operator.filter.*; import cn.edu.tsinghua.iginx.engine.shared.operator.tag.AndTagFilter; import cn.edu.tsinghua.iginx.engine.shared.operator.tag.BasePreciseTagFilter; @@ -102,6 +103,7 @@ import cn.edu.tsinghua.iginx.sql.SqlParser.OrPreciseExpressionContext; import cn.edu.tsinghua.iginx.sql.SqlParser.OrTagExpressionContext; import cn.edu.tsinghua.iginx.sql.SqlParser.OrderByClauseContext; +import cn.edu.tsinghua.iginx.sql.SqlParser.OrderItemContext; import cn.edu.tsinghua.iginx.sql.SqlParser.ParamContext; import cn.edu.tsinghua.iginx.sql.SqlParser.PathContext; import cn.edu.tsinghua.iginx.sql.SqlParser.PreciseTagExpressionContext; @@ -129,7 +131,7 @@ import cn.edu.tsinghua.iginx.sql.SqlParser.ShowReplicationStatementContext; import cn.edu.tsinghua.iginx.sql.SqlParser.ShowRulesStatementContext; import cn.edu.tsinghua.iginx.sql.SqlParser.ShowSessionIDStatementContext; -import cn.edu.tsinghua.iginx.sql.SqlParser.SimipleCaseContext; +import cn.edu.tsinghua.iginx.sql.SqlParser.SimpleCaseContext; import cn.edu.tsinghua.iginx.sql.SqlParser.SimpleWhenClauseContext; import cn.edu.tsinghua.iginx.sql.SqlParser.SpecialClauseContext; import cn.edu.tsinghua.iginx.sql.SqlParser.SqlStatementContext; @@ -159,6 +161,7 @@ import cn.edu.tsinghua.iginx.sql.utils.ExpressionUtils; import cn.edu.tsinghua.iginx.thrift.*; import cn.edu.tsinghua.iginx.utils.Pair; +import cn.edu.tsinghua.iginx.utils.StringUtils; import cn.edu.tsinghua.iginx.utils.TimeUtils; import java.util.*; import java.util.stream.Collectors; @@ -1179,8 +1182,8 @@ private Expression parseBaseExpression( private Expression parseCaseWhenExpression( CaseSpecificationContext ctx, UnarySelectStatement selectStatement) { - if (ctx.simipleCase() != null) { - return parseSimpleCase(ctx.simipleCase(), selectStatement); + if (ctx.simpleCase() != null) { + return parseSimpleCase(ctx.simpleCase(), selectStatement); } else if (ctx.searchedCase() != null) { return parseSearchedCase(ctx.searchedCase(), selectStatement); } else { @@ -1189,7 +1192,7 @@ private Expression parseCaseWhenExpression( } private CaseWhenExpression parseSimpleCase( - SimipleCaseContext ctx, UnarySelectStatement selectStatement) { + SimpleCaseContext ctx, UnarySelectStatement selectStatement) { List conditions = new ArrayList<>(); List results = new ArrayList<>(); Expression leftExpr = parseExpression(ctx.expression(), selectStatement).get(0); @@ -1327,37 +1330,92 @@ private void parseDownsampleClause( } private void parseGroupByClause(GroupByClauseContext ctx, UnarySelectStatement selectStatement) { - if (ExprUtils.hasCaseWhen(selectStatement.getExpressions())) { - throw new SQLParserException( - "CASE WHEN is not supported to be selected when sql has GROUP BY."); - } selectStatement.setHasGroupBy(true); + for (ExpressionContext exprCtx : ctx.expression()) { + if (exprCtx.subquery() != null) { + throw new SQLParserException("Subquery is not supported in GROUP BY columns."); + } - ctx.path() - .forEach( - pathContext -> { - String path = parsePath(pathContext); - // 如果查询语句的FROM子句只有一个部分且FROM一个前缀,则GROUP BY后的path只用写出后缀 - if (selectStatement.isFromSinglePath()) { - path = selectStatement.getFromPart(0).getPrefix() + SQLConstant.DOT + path; - } - if (path.contains("*")) { - throw new SQLParserException( - String.format("GROUP BY path '%s' has '*', which is not supported.", path)); - } - selectStatement.setGroupByPath(path); + Expression expr = parseExpression(exprCtx, selectStatement, false).get(0); + + if (expr instanceof BaseExpression) { + BaseExpression baseExpr = (BaseExpression) expr; + Set groupByExprSet = new HashSet<>(); + String path = baseExpr.getPathName(); + if (path.contains("*")) { + throw new SQLParserException( + String.format("GROUP BY column '%s' has '*', which is not supported.", path)); + } + // 删去在解析expression时加上的前缀 + if (selectStatement.isFromSinglePath()) { + path = path.replaceFirst(selectStatement.getFromPart(0).getPrefix() + "\\.", ""); + } + + for (Expression selectExpr : selectStatement.getExpressions()) { + if (selectExpr.equalExceptAlias(baseExpr)) { // 匹配select表达式 + groupByExprSet.add(selectExpr); + continue; + } + if (selectExpr.getAlias().equals(path)) { // 匹配select表达式别名 + groupByExprSet.add(selectExpr); + } + } + + // 匹配到了多个select表达式 + if (groupByExprSet.size() > 1) { + throw new SQLParserException(String.format("GROUP BY column '%s' is ambiguous.", path)); + } + + // GROUP BY的表达式没有出现在SELECT子句中 + if (groupByExprSet.isEmpty()) { + selectStatement.setGroupByExpr(expr); + String originPath = selectStatement.getOriginPath(path); + if (originPath != null) { + selectStatement.addGroupByPath(originPath); + } + } else { + selectStatement.setGroupByExpr(groupByExprSet.iterator().next()); + } + } else { + MappingType type = ExpressionUtils.getExprMappingType(expr); + if (type == MappingType.SetMapping || type == MappingType.Mapping) { + throw new SQLParserException("GROUP BY column can not use SetToSet/SetToRow functions."); + } + selectStatement.setGroupByExpr(expr); + + List baseExpressions = + ExpressionUtils.getBaseExpressionList(Collections.singletonList(expr), false); + baseExpressions.forEach( + baseExpression -> { + String path = baseExpression.getPathName(); String originPath = selectStatement.getOriginPath(path); if (originPath != null) { selectStatement.addGroupByPath(originPath); } }); + } + } selectStatement - .getBaseExpressionList(true) + .getExpressions() .forEach( - expr -> { - if (!selectStatement.getGroupByPaths().contains(expr.getPathName())) { - throw new SQLParserException("Selected path must exist in group by clause."); + selectExpr -> { + if (ExpressionUtils.getExprMappingType(selectExpr) == MappingType.RowMapping) { + boolean foundInGroupBy = false; + for (int i = 0; i < selectStatement.getGroupByExpressions().size(); i++) { + Expression groupByExpr = selectStatement.getGroupByExpressions().get(i); + if (selectExpr.equalExceptAlias(groupByExpr)) { + selectStatement.getGroupByExpressions().set(i, selectExpr); + foundInGroupBy = true; + break; + } + } + if (!foundInGroupBy) { + throw new SQLParserException( + String.format( + "Selected expression '%s' does not exist in GROUP BY clause.", + selectExpr.getColumnName())); + } } }); } @@ -1396,7 +1454,7 @@ private Pair getLimitAndOffsetFromCtx(LimitClauseContext ctx) private void parseOrderByClause(OrderByClauseContext ctx, SelectStatement selectStatement) { if (ctx.KEY() != null) { - selectStatement.setOrderByPath(SQLConstant.KEY); + selectStatement.setOrderByExpr(new KeyExpression(SQLConstant.KEY)); selectStatement.setAscending(ctx.DESC() == null); } if (ctx.orderItem() != null) { @@ -1406,23 +1464,88 @@ private void parseOrderByClause(OrderByClauseContext ctx, SelectStatement select } } - private void parseOrderItem(SqlParser.OrderItemContext ctx, SelectStatement selectStatement) { - String suffix = parsePath(ctx.path()); - String orderByPath = suffix; - if (selectStatement.getSelectType() == SelectStatement.SelectStatementType.UNARY) { - UnarySelectStatement unarySelectStatement = (UnarySelectStatement) selectStatement; - String prefix = unarySelectStatement.getFromPart(0).getPrefix(); + private void parseOrderItem(OrderItemContext ctx, SelectStatement selectStatement) { + if (ctx.expression().subquery() != null) { + throw new SQLParserException("Subquery is not supported in ORDER BY columns."); + } + + UnarySelectStatement unarySelectStatement = selectStatement.getFirstUnarySelectStatement(); + Expression expr = parseExpression(ctx.expression(), unarySelectStatement, false).get(0); + if (expr instanceof BaseExpression) { + BaseExpression baseExpr = (BaseExpression) expr; + Set orderByExprSet = new HashSet<>(); + String path = baseExpr.getPathName(); + if (path.contains("*")) { + throw new SQLParserException( + String.format("ORDER BY column '%s' has '*', which is not supported.", path)); + } + if (selectStatement.getSelectType() == SelectStatement.SelectStatementType.UNARY) { + UnarySelectStatement stmt = (UnarySelectStatement) selectStatement; + String pathRemovePrefix = path; + // 删去在解析expression时加上的前缀 + if (stmt.isFromSinglePath()) { + pathRemovePrefix = path.replaceFirst(stmt.getFromPart(0).getPrefix() + "\\.", ""); + } + + for (Expression selectExpr : selectStatement.getExpressions()) { + if (StringUtils.match(path, selectExpr.getColumnName())) { // 匹配select表达式 + orderByExprSet.add(expr); + continue; + } + if (selectExpr.getAlias().equals(pathRemovePrefix)) { // 匹配select表达式别名 + orderByExprSet.add(selectExpr); + } + } + + // 匹配到了多个select表达式 + if (orderByExprSet.size() > 1) { + throw new SQLParserException(String.format("ORDER BY column '%s' is ambiguous.", path)); + } - // 如果查询语句的FROM子句只有一个部分且FROM一个前缀,则ORDER BY后的path只用写出后缀 - if (unarySelectStatement.isFromSinglePath()) { - orderByPath = prefix + SQLConstant.DOT + suffix; + // ORDER BY的表达式没有出现在SELECT子句中 + if (orderByExprSet.isEmpty()) { + selectStatement.setOrderByExpr(expr); + String originPath = selectStatement.getOriginPath(path); + if (originPath != null) { + ((UnarySelectStatement) selectStatement).addOrderByPath(originPath); + } + } else { + selectStatement.setOrderByExpr(orderByExprSet.iterator().next()); + } } + } else { + MappingType type = ExpressionUtils.getExprMappingType(expr); + if (type == MappingType.SetMapping || type == MappingType.Mapping) { + throw new SQLParserException("ORDER BY column can not use SetToSet/SetToRow functions."); + } + + // 在SELECT子句中查找相同的表达式,避免重复计算(主要是case when) + boolean foundInSelect = false; + for (Expression selectExpr : selectStatement.getExpressions()) { + if (ExpressionUtils.getExprMappingType(selectExpr) == MappingType.RowMapping + && selectExpr.equalExceptAlias(expr)) { + selectStatement.setOrderByExpr(selectExpr); + foundInSelect = true; + break; + } + } + if (!foundInSelect) { + selectStatement.setOrderByExpr(expr); + } + + // 查找需要加入到pathSet的path + List baseExpressions = + ExpressionUtils.getBaseExpressionList(Collections.singletonList(expr), false); + baseExpressions.forEach( + baseExpression -> { + String path = baseExpression.getPathName(); + String originPath = selectStatement.getOriginPath(path); + if (originPath != null) { + ((UnarySelectStatement) selectStatement).addOrderByPath(originPath); + } + }); } - if (orderByPath.contains("*")) { - throw new SQLParserException( - String.format("ORDER BY path '%s' has '*', which is not supported.", orderByPath)); - } - selectStatement.setOrderByPath(orderByPath); + selectStatement.setAscending(ctx.DESC() == null); } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/BinarySelectStatement.java b/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/BinarySelectStatement.java index 8cce659cd2..bebe41de83 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/BinarySelectStatement.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/BinarySelectStatement.java @@ -68,6 +68,11 @@ public List getExpressions() { return leftQuery.getExpressions(); } + @Override + public UnarySelectStatement getFirstUnarySelectStatement() { + return leftQuery.getFirstUnarySelectStatement(); + } + @Override public Set getPathSet() { Set pathSet = new HashSet<>(leftQuery.getPathSet()); diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/SelectStatement.java b/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/SelectStatement.java index dd51161795..9ea93f3d96 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/SelectStatement.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/SelectStatement.java @@ -83,12 +83,18 @@ public boolean isSubQuery() { public abstract Set getPathSet(); - public List getOrderByPaths() { - return orderByClause.getOrderByPaths(); + public abstract UnarySelectStatement getFirstUnarySelectStatement(); + + public String getOriginPath(String path) { + return null; + } + + public List getOrderByExpressions() { + return orderByClause.getOrderByExpressions(); } - public void setOrderByPath(String orderByPath) { - this.orderByClause.setOrderByPaths(orderByPath); + public void setOrderByExpr(Expression orderByExpr) { + this.orderByClause.setOrderByExpr(orderByExpr); } public List getAscendingList() { diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/UnarySelectStatement.java b/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/UnarySelectStatement.java index 5b60d4ee03..1b6eb39dc5 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/UnarySelectStatement.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/UnarySelectStatement.java @@ -22,7 +22,6 @@ import static cn.edu.tsinghua.iginx.sql.SQLConstant.L_PARENTHESES; import static cn.edu.tsinghua.iginx.sql.SQLConstant.R_PARENTHESES; -import cn.edu.tsinghua.iginx.engine.physical.memory.execute.utils.ExprUtils; import cn.edu.tsinghua.iginx.engine.physical.memory.execute.utils.FilterUtils; import cn.edu.tsinghua.iginx.engine.shared.expr.*; import cn.edu.tsinghua.iginx.engine.shared.function.FunctionUtils; @@ -241,12 +240,7 @@ public List getBaseExpressionList() { * @return BaseExpression列表 */ public List getBaseExpressionList(boolean exceptFunc) { - List paths = ExprUtils.getPathFromExprList(getExpressions(), exceptFunc); - List baseExpressionList = new ArrayList<>(paths.size()); - for (String path : paths) { - baseExpressionList.add(new BaseExpression(path)); - } - return baseExpressionList; + return ExpressionUtils.getBaseExpressionList(getExpressions(), exceptFunc); } public List getSequenceExpressionList() { @@ -260,6 +254,7 @@ public Set getPathSet() { pathSet.addAll(whereClause.getPathSet()); pathSet.addAll(groupByClause.getPathSet()); pathSet.addAll(havingClause.getPathSet()); + pathSet.addAll(orderByClause.getPathSet()); return pathSet; } @@ -279,6 +274,10 @@ public void addHavingPath(String path) { havingClause.addPath(path); } + public void addOrderByPath(String path) { + orderByClause.addPath(path); + } + public List getSelectSubQueryParts() { return selectClause.getSelectSubQueryParts(); } @@ -307,8 +306,8 @@ public void addWhereSubQueryPart(SubQueryFromPart whereSubQueryPart) { whereClause.addWhereSubQueryPart(whereSubQueryPart); } - public void setGroupByPath(String path) { - this.groupByClause.addGroupByPath(path); + public void setGroupByExpr(Expression expression) { + this.groupByClause.addGroupByExpression(expression); } public List getHavingSubQueryParts() { @@ -319,12 +318,12 @@ public void addHavingSubQueryPart(SubQueryFromPart havingSubQueryPart) { this.havingClause.addHavingSubQueryPart(havingSubQueryPart); } - public List getGroupByPaths() { - return groupByClause.getGroupByPaths(); + public List getGroupByExpressions() { + return groupByClause.getGroupByExpressions(); } - public void setOrderByPath(String orderByPath) { - this.orderByClause.setOrderByPaths(orderByPath); + public void setOrderByExpr(Expression orderByExpr) { + super.setOrderByExpr(orderByExpr); } public Filter getFilter() { @@ -401,6 +400,11 @@ public List getExpressions() { return new ArrayList<>(selectClause.getExpressions()); } + @Override + public UnarySelectStatement getFirstUnarySelectStatement() { + return this; + } + public void addSelectClauseExpression(Expression expression) { selectClause.addExpression(expression); } @@ -453,6 +457,11 @@ public List> getSubQueryAliasList(String alias) { public boolean needRowTransform() { for (Expression expression : getExpressions()) { + if (getQueryType() == QueryType.GroupByQuery + && groupByClause.getGroupByExpressions().stream() + .anyMatch(e -> e.equalExceptAlias(expression))) { + continue; + } if (expression.getType().equals(Expression.ExpressionType.Function)) { FuncExpression funcExpression = (FuncExpression) expression; if (FunctionUtils.isRowToRowFunction(funcExpression.getFuncName())) { @@ -652,16 +661,13 @@ public void checkQueryType() { Set typeList = new HashSet<>(); for (Expression expression : getExpressions()) { - typeList.add(getExprMappingType(expression)); + typeList.add(ExpressionUtils.getExprMappingType(expression)); } typeList.remove(null); if (hasGroupBy()) { if (typeList.contains(MappingType.Mapping)) { throw new SQLParserException("Group by can not use SetToSet functions."); - } else if (typeList.contains(MappingType.RowMapping) - && !getTargetTypeFuncExprList(MappingType.RowMapping).isEmpty()) { - throw new SQLParserException("Group by can not use RowToRow functions."); } setQueryType(QueryType.GroupByQuery); return; @@ -709,93 +715,6 @@ private static boolean isNeedNoPath(Expression expression) { return ExpressionUtils.isConstantArithmeticExpr(expression); } - /** - * 判断Expression的FuncExpression的映射类型 - * - * @param expression 给定Expression - * @return Expression的函数映射类型。若为ConstantExpression,返回null - */ - private MappingType getExprMappingType(Expression expression) { - switch (expression.getType()) { - case Constant: - case Sequence: - case FromValue: - return null; - case Base: - case Key: - case CaseWhen: - return MappingType.RowMapping; // case-when视为RowMapping函数 - case Unary: - return getExprMappingType(((UnaryExpression) expression).getExpression()); - case Bracket: - return getExprMappingType(((BracketExpression) expression).getExpression()); - case Function: - FuncExpression funcExpr = (FuncExpression) expression; - MappingType funcMappingType = FunctionUtils.getFunctionMappingType(funcExpr.getFuncName()); - Set childTypeSet = new HashSet<>(); - MappingType retType = funcMappingType; - for (Expression child : funcExpr.getExpressions()) { - MappingType childType = getExprMappingType(child); - childTypeSet.add(childType); - if (funcMappingType == MappingType.SetMapping) { - if (childType != null && childType != MappingType.RowMapping) { - throw new SQLParserException( - "SetToRow functions can not be nested with SetToSet/SetToRow functions."); - } - } else if (funcMappingType == MappingType.Mapping) { - if (childType != null && childType != MappingType.RowMapping) { - throw new SQLParserException( - "SetToSet functions can not be nested with SetToSet/SetToRow functions."); - } - } else { - if (childType != null) { - retType = childType; - } - } - } - childTypeSet.remove(null); - if (childTypeSet.size() > 1) { - throw new SQLParserException( - "SetToSet/SetToRow/RowToRow functions can not be mixed in function params."); - } - return retType; - case Binary: - BinaryExpression binaryExpr = (BinaryExpression) expression; - MappingType leftType = getExprMappingType(binaryExpr.getLeftExpression()); - MappingType rightType = getExprMappingType(binaryExpr.getRightExpression()); - if (leftType != null && rightType != null) { - if (leftType != rightType) { - throw new SQLParserException( - "SetToSet/SetToRow/RowToRow functions can not be mixed in BinaryExpression."); - } - return leftType; - } - if (leftType == null) { - return rightType; - } - return leftType; - case Multiple: - MultipleExpression multipleExpr = (MultipleExpression) expression; - Set typeSet = new HashSet<>(); - for (Expression child : multipleExpr.getChildren()) { - MappingType childType = getExprMappingType(child); - if (childType != null) { - typeSet.add(childType); - } - } - if (typeSet.size() == 1) { - return typeSet.iterator().next(); - } else if (typeSet.size() > 1) { - throw new SQLParserException( - "SetToSet/SetToRow/RowToRow functions can not be mixed in MultipleExpression."); - } else { - return null; - } - default: - throw new SQLParserException("Unknown expression type: " + expression.getType()); - } - } - public enum QueryType { Unknown, SimpleQuery, diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/subclause/GroupByClause.java b/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/subclause/GroupByClause.java index 017c1b82de..07ab9e17ae 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/subclause/GroupByClause.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/subclause/GroupByClause.java @@ -18,6 +18,7 @@ package cn.edu.tsinghua.iginx.sql.statement.select.subclause; +import cn.edu.tsinghua.iginx.engine.shared.expr.Expression; import cn.edu.tsinghua.iginx.sql.statement.select.UnarySelectStatement.QueryType; import java.util.ArrayList; import java.util.HashSet; @@ -28,26 +29,26 @@ public class GroupByClause { private boolean hasDownsample; private boolean hasGroupBy; private QueryType queryType; - private final List groupByPaths; + private final List groupByExpressions; private final Set pathSet; private long precision; private long slideDistance; public GroupByClause() { - groupByPaths = new ArrayList<>(); + groupByExpressions = new ArrayList<>(); pathSet = new HashSet<>(); hasDownsample = false; hasGroupBy = false; queryType = QueryType.Unknown; } - public void addGroupByPath(String path) { - groupByPaths.add(path); + public void addGroupByExpression(Expression expression) { + groupByExpressions.add(expression); hasGroupBy = true; } - public List getGroupByPaths() { - return groupByPaths; + public List getGroupByExpressions() { + return groupByExpressions; } public boolean hasDownsample() { diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/subclause/OrderByClause.java b/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/subclause/OrderByClause.java index 585f75a831..f2759908eb 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/subclause/OrderByClause.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/sql/statement/select/subclause/OrderByClause.java @@ -18,25 +18,31 @@ package cn.edu.tsinghua.iginx.sql.statement.select.subclause; +import cn.edu.tsinghua.iginx.engine.shared.expr.Expression; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; public class OrderByClause { - private final List orderByPaths; + private final List orderByExpressions; private final List ascendingList; + private final Set pathSet; - public OrderByClause(List orderByPaths, List ascendingList) { - this.orderByPaths = orderByPaths; + public OrderByClause(List orderByExpressions, List ascendingList) { + this.orderByExpressions = orderByExpressions; this.ascendingList = ascendingList; + this.pathSet = new HashSet<>(); } public OrderByClause() { - this.orderByPaths = new ArrayList<>(); + this.orderByExpressions = new ArrayList<>(); this.ascendingList = new ArrayList<>(); + this.pathSet = new HashSet<>(); } - public List getOrderByPaths() { - return orderByPaths; + public List getOrderByExpressions() { + return orderByExpressions; } public List getAscendingList() { @@ -47,7 +53,15 @@ public void setAscendingList(boolean ascending) { this.ascendingList.add(ascending); } - public void setOrderByPaths(String orderByPath) { - this.orderByPaths.add(orderByPath); + public void setOrderByExpr(Expression orderByExpr) { + this.orderByExpressions.add(orderByExpr); + } + + public Set getPathSet() { + return pathSet; + } + + public void addPath(String path) { + pathSet.add(path); } } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/sql/utils/ExpressionUtils.java b/core/src/main/java/cn/edu/tsinghua/iginx/sql/utils/ExpressionUtils.java index 899bf55126..d220f91159 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/sql/utils/ExpressionUtils.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/sql/utils/ExpressionUtils.java @@ -18,6 +18,8 @@ package cn.edu.tsinghua.iginx.sql.utils; +import cn.edu.tsinghua.iginx.engine.physical.memory.execute.utils.ExprUtils; +import cn.edu.tsinghua.iginx.engine.shared.expr.BaseExpression; import cn.edu.tsinghua.iginx.engine.shared.expr.BinaryExpression; import cn.edu.tsinghua.iginx.engine.shared.expr.BracketExpression; import cn.edu.tsinghua.iginx.engine.shared.expr.Expression; @@ -25,6 +27,13 @@ import cn.edu.tsinghua.iginx.engine.shared.expr.MultipleExpression; import cn.edu.tsinghua.iginx.engine.shared.expr.Operator; import cn.edu.tsinghua.iginx.engine.shared.expr.UnaryExpression; +import cn.edu.tsinghua.iginx.engine.shared.function.FunctionUtils; +import cn.edu.tsinghua.iginx.engine.shared.function.MappingType; +import cn.edu.tsinghua.iginx.sql.exception.SQLParserException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; public class ExpressionUtils { @@ -73,4 +82,108 @@ public static String transformToBaseExpr(Expression expression) { return null; } } + + /** + * 从expressions中获取所有的BaseExpression + * + * @param expressions Expression列表 + * @param exceptFunc 是否不包括FuncExpression参数中的BaseExpression + * @return BaseExpression列表 + */ + public static List getBaseExpressionList( + List expressions, boolean exceptFunc) { + List paths = ExprUtils.getPathFromExprList(expressions, exceptFunc); + List baseExpressionList = new ArrayList<>(paths.size()); + for (String path : paths) { + baseExpressionList.add(new BaseExpression(path)); + } + return baseExpressionList; + } + + /** + * 判断Expression的FuncExpression的映射类型 + * + * @param expression 给定Expression + * @return Expression的函数映射类型。若为ConstantExpression,返回null + */ + public static MappingType getExprMappingType(Expression expression) { + switch (expression.getType()) { + case Constant: + case Sequence: + case FromValue: + return null; + case Base: + case Key: + case CaseWhen: + return MappingType.RowMapping; // case-when视为RowMapping函数 + case Unary: + return getExprMappingType(((UnaryExpression) expression).getExpression()); + case Bracket: + return getExprMappingType(((BracketExpression) expression).getExpression()); + case Function: + FuncExpression funcExpr = (FuncExpression) expression; + MappingType funcMappingType = FunctionUtils.getFunctionMappingType(funcExpr.getFuncName()); + Set childTypeSet = new HashSet<>(); + MappingType retType = funcMappingType; + for (Expression child : funcExpr.getExpressions()) { + MappingType childType = getExprMappingType(child); + childTypeSet.add(childType); + if (funcMappingType == MappingType.SetMapping) { + if (childType != null && childType != MappingType.RowMapping) { + throw new SQLParserException( + "SetToRow functions can not be nested with SetToSet/SetToRow functions."); + } + } else if (funcMappingType == MappingType.Mapping) { + if (childType != null && childType != MappingType.RowMapping) { + throw new SQLParserException( + "SetToSet functions can not be nested with SetToSet/SetToRow functions."); + } + } else { + if (childType != null) { + retType = childType; + } + } + } + childTypeSet.remove(null); + if (childTypeSet.size() > 1) { + throw new SQLParserException( + "SetToSet/SetToRow/RowToRow functions can not be mixed in function params."); + } + return retType; + case Binary: + BinaryExpression binaryExpr = (BinaryExpression) expression; + MappingType leftType = getExprMappingType(binaryExpr.getLeftExpression()); + MappingType rightType = getExprMappingType(binaryExpr.getRightExpression()); + if (leftType != null && rightType != null) { + if (leftType != rightType) { + throw new SQLParserException( + "SetToSet/SetToRow/RowToRow functions can not be mixed in BinaryExpression."); + } + return leftType; + } + if (leftType == null) { + return rightType; + } + return leftType; + case Multiple: + MultipleExpression multipleExpr = (MultipleExpression) expression; + Set typeSet = new HashSet<>(); + for (Expression child : multipleExpr.getChildren()) { + MappingType childType = getExprMappingType(child); + if (childType != null) { + typeSet.add(childType); + } + } + if (typeSet.size() == 1) { + return typeSet.iterator().next(); + } else if (typeSet.size() > 1) { + throw new SQLParserException( + "SetToSet/SetToRow/RowToRow functions can not be mixed in MultipleExpression."); + } else { + return null; + } + default: + throw new SQLParserException("Unknown expression type: " + expression.getType()); + } + } } diff --git a/core/src/test/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/AbstractOperatorMemoryExecutorTest.java b/core/src/test/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/AbstractOperatorMemoryExecutorTest.java index 2c38628fa6..111e5b7d5e 100644 --- a/core/src/test/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/AbstractOperatorMemoryExecutorTest.java +++ b/core/src/test/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/AbstractOperatorMemoryExecutorTest.java @@ -26,7 +26,6 @@ import cn.edu.tsinghua.iginx.engine.physical.exception.InvalidOperatorParameterException; import cn.edu.tsinghua.iginx.engine.physical.exception.PhysicalException; -import cn.edu.tsinghua.iginx.engine.shared.Constants; import cn.edu.tsinghua.iginx.engine.shared.KeyRange; import cn.edu.tsinghua.iginx.engine.shared.data.Value; import cn.edu.tsinghua.iginx.engine.shared.data.read.Field; @@ -34,6 +33,7 @@ import cn.edu.tsinghua.iginx.engine.shared.data.read.Row; import cn.edu.tsinghua.iginx.engine.shared.data.read.RowStream; import cn.edu.tsinghua.iginx.engine.shared.expr.BaseExpression; +import cn.edu.tsinghua.iginx.engine.shared.expr.KeyExpression; import cn.edu.tsinghua.iginx.engine.shared.function.FunctionCall; import cn.edu.tsinghua.iginx.engine.shared.function.FunctionParams; import cn.edu.tsinghua.iginx.engine.shared.function.system.Avg; @@ -65,6 +65,7 @@ import cn.edu.tsinghua.iginx.engine.shared.operator.type.JoinAlgType; import cn.edu.tsinghua.iginx.engine.shared.operator.type.OuterJoinType; import cn.edu.tsinghua.iginx.engine.shared.source.EmptySource; +import cn.edu.tsinghua.iginx.sql.SQLConstant; import cn.edu.tsinghua.iginx.thrift.DataType; import java.util.ArrayList; import java.util.Arrays; @@ -2176,7 +2177,7 @@ public void testSortByTimeAsc() throws PhysicalException { Sort sort = new Sort( EmptySource.EMPTY_SOURCE, - Collections.singletonList(Constants.KEY), + Collections.singletonList(new KeyExpression(SQLConstant.KEY)), Collections.singletonList(Sort.SortType.ASC)); RowStream stream = getExecutor().executeUnaryOperator(sort, table, null); assertEquals(table.getHeader(), stream.getHeader()); @@ -2197,7 +2198,7 @@ public void testSortByTimeDesc() throws PhysicalException { Sort sort = new Sort( EmptySource.EMPTY_SOURCE, - Collections.singletonList(Constants.KEY), + Collections.singletonList(new KeyExpression(SQLConstant.KEY)), Collections.singletonList(Sort.SortType.DESC)); RowStream stream = getExecutor().executeUnaryOperator(sort, copyTable, null); assertEquals(table.getHeader(), stream.getHeader()); diff --git a/core/src/test/java/cn/edu/tsinghua/iginx/sql/ParseTest.java b/core/src/test/java/cn/edu/tsinghua/iginx/sql/ParseTest.java index dfc0f3fbc1..d8be206a8d 100644 --- a/core/src/test/java/cn/edu/tsinghua/iginx/sql/ParseTest.java +++ b/core/src/test/java/cn/edu/tsinghua/iginx/sql/ParseTest.java @@ -22,7 +22,9 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import cn.edu.tsinghua.iginx.engine.shared.expr.BaseExpression; import cn.edu.tsinghua.iginx.engine.shared.expr.FuncExpression; +import cn.edu.tsinghua.iginx.engine.shared.expr.KeyExpression; import cn.edu.tsinghua.iginx.engine.shared.function.MappingType; import cn.edu.tsinghua.iginx.engine.shared.operator.filter.Op; import cn.edu.tsinghua.iginx.engine.shared.operator.filter.PathFilter; @@ -144,12 +146,19 @@ public void testParseSpecialClause() { String orderBy = "SELECT a FROM test ORDER BY KEY"; statement = (UnarySelectStatement) TestUtils.buildStatement(orderBy); - assertEquals(Collections.singletonList(SQLConstant.KEY), statement.getOrderByPaths()); + assertEquals(1, statement.getOrderByExpressions().size()); + assertTrue( + statement + .getOrderByExpressions() + .get(0) + .equalExceptAlias(new KeyExpression(SQLConstant.KEY))); assertTrue(statement.getAscendingList().get(0)); String orderByAndLimit = "SELECT a FROM test ORDER BY a DESC LIMIT 10 OFFSET 5;"; statement = (UnarySelectStatement) TestUtils.buildStatement(orderByAndLimit); - assertEquals(Collections.singletonList("test.a"), statement.getOrderByPaths()); + assertEquals(1, statement.getOrderByExpressions().size()); + assertTrue( + statement.getOrderByExpressions().get(0).equalExceptAlias(new BaseExpression("test.a"))); assertFalse(statement.getAscendingList().get(0)); assertEquals(5, statement.getOffset()); assertEquals(10, statement.getLimit()); diff --git a/optimizer/src/main/java/cn/edu/tsinghua/iginx/logical/optimizer/rules/ColumnPruningRule.java b/optimizer/src/main/java/cn/edu/tsinghua/iginx/logical/optimizer/rules/ColumnPruningRule.java index 0d01fac51c..d1109a803a 100644 --- a/optimizer/src/main/java/cn/edu/tsinghua/iginx/logical/optimizer/rules/ColumnPruningRule.java +++ b/optimizer/src/main/java/cn/edu/tsinghua/iginx/logical/optimizer/rules/ColumnPruningRule.java @@ -18,13 +18,19 @@ package cn.edu.tsinghua.iginx.logical.optimizer.rules; +import static cn.edu.tsinghua.iginx.engine.shared.function.system.ArithmeticExpr.ARITHMETIC_EXPR; + import cn.edu.tsinghua.iginx.engine.logical.utils.OperatorUtils; import cn.edu.tsinghua.iginx.engine.logical.utils.PathUtils; import cn.edu.tsinghua.iginx.engine.physical.memory.execute.utils.ExprUtils; import cn.edu.tsinghua.iginx.engine.physical.memory.execute.utils.FilterUtils; import cn.edu.tsinghua.iginx.engine.shared.Constants; +import cn.edu.tsinghua.iginx.engine.shared.expr.Expression; +import cn.edu.tsinghua.iginx.engine.shared.expr.KeyExpression; import cn.edu.tsinghua.iginx.engine.shared.function.FunctionCall; +import cn.edu.tsinghua.iginx.engine.shared.function.FunctionParams; import cn.edu.tsinghua.iginx.engine.shared.function.FunctionUtils; +import cn.edu.tsinghua.iginx.engine.shared.function.manager.FunctionManager; import cn.edu.tsinghua.iginx.engine.shared.function.system.ArithmeticExpr; import cn.edu.tsinghua.iginx.engine.shared.function.system.First; import cn.edu.tsinghua.iginx.engine.shared.function.system.Last; @@ -51,6 +57,8 @@ public class ColumnPruningRule extends Rule { private static final Logger LOGGER = LoggerFactory.getLogger(ColumnPruningRule.class); + private static final FunctionManager functionManager = FunctionManager.getInstance(); + public ColumnPruningRule() { /* * we want to match the topology like: @@ -124,16 +132,23 @@ private void collectColumns( } else if (operator.getType() == OperatorType.GroupBy) { GroupBy groupBy = (GroupBy) operator; - newColumnList = groupBy.getGroupByCols(); - functionCallList = groupBy.getFunctionCallList(); + functionCallList = new ArrayList<>(groupBy.getFunctionCallList()); + for (Expression groupByExpr : groupBy.getGroupByExpressions()) { + functionCallList.add( + new FunctionCall( + functionManager.getFunction(ARITHMETIC_EXPR), new FunctionParams(groupByExpr))); + } } else if (operator.getType() == OperatorType.Downsample) { Downsample downsample = (Downsample) operator; functionCallList = downsample.getFunctionCallList(); } else if (operator.getType() == OperatorType.Sort) { Sort sort = (Sort) operator; - for (String column : sort.getSortByCols()) { - if (!column.equalsIgnoreCase(Constants.KEY)) { - columns.add(column); + functionCallList = new ArrayList<>(); + for (Expression sortByExpr : sort.getSortByExpressions()) { + if (!(sortByExpr instanceof KeyExpression)) { + functionCallList.add( + new FunctionCall( + functionManager.getFunction(ARITHMETIC_EXPR), new FunctionParams(sortByExpr))); } } } else if (operator.getType() == OperatorType.AddSchemaPrefix) { diff --git a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java index 7bff1fb74e..0b4516f75e 100644 --- a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java +++ b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java @@ -2939,6 +2939,64 @@ public void testGroupByWithCaseWhen() { executor.executeAndCompare(query, expected); } + @Test + public void testGroupByAndOrderByExpr() { + String insert = + "INSERT INTO student(key, s_id, name, sex, age) VALUES " + + "(0, 1, \"Alan\", 1, 16), (1, 2, \"Bob\", 1, 14), (2, 3, \"Candy\", 0, 17), " + + "(3, 4, \"Alice\", 0, 22), (4, 5, \"Jack\", 1, 36), (5, 6, \"Tom\", 1, 20);"; + executor.execute(insert); + insert = + "INSERT INTO math(key, s_id, score) VALUES (0, 1, 82), (1, 2, 58), (2, 3, 54), (3, 4, 92), (4, 5, 78), (5, 6, 98);"; + executor.execute(insert); + + // use alias in GROUP BY and ORDER BY + String statement = + "SELECT avg(math.score) as avg_score, CASE student.sex WHEN 1 THEN 'Male' WHEN 0 THEN 'Female' ELSE 'Unknown' END AS strSex\n" + + "FROM student JOIN math ON student.s_id = math.s_id\n" + + "GROUP BY strSex ORDER BY strSex;"; + String expected = + "ResultSets:\n" + + "+---------+------+\n" + + "|avg_score|strSex|\n" + + "+---------+------+\n" + + "| 73.0|Female|\n" + + "| 79.0| Male|\n" + + "+---------+------+\n" + + "Total line number = 2\n"; + executor.executeAndCompare(statement, expected); + + // don't use alias in GROUP BY and ORDER BY + statement = + "SELECT avg(math.score) as avg_score, CASE student.sex WHEN 1 THEN 'Male' WHEN 0 THEN 'Female' ELSE 'Unknown' END AS strSex\n" + + "FROM student JOIN math ON student.s_id = math.s_id\n" + + "GROUP BY CASE student.sex WHEN 1 THEN 'Male' WHEN 0 THEN 'Female' ELSE 'Unknown' END\n" + + "ORDER BY CASE student.sex WHEN 1 THEN 'Male' WHEN 0 THEN 'Female' ELSE 'Unknown' END DESC;"; + expected = + "ResultSets:\n" + + "+---------+------+\n" + + "|avg_score|strSex|\n" + + "+---------+------+\n" + + "| 79.0| Male|\n" + + "| 73.0|Female|\n" + + "+---------+------+\n" + + "Total line number = 2\n"; + executor.executeAndCompare(statement, expected); + + statement = "SELECT s_id % 3 AS id, sum(score) FROM math GROUP BY id ORDER BY id;"; + expected = + "ResultSets:\n" + + "+--+---------------+\n" + + "|id|sum(math.score)|\n" + + "+--+---------------+\n" + + "| 0| 152|\n" + + "| 1| 174|\n" + + "| 2| 136|\n" + + "+--+---------------+\n" + + "Total line number = 3\n"; + executor.executeAndCompare(statement, expected); + } + @Test public void testJoinWithGroupBy() { String insert = diff --git a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/udf/UDFIT.java b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/udf/UDFIT.java index 3a53f49aed..43b86bce65 100644 --- a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/udf/UDFIT.java +++ b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/udf/UDFIT.java @@ -677,6 +677,64 @@ public void testColumnExpand() { compareResult(expected, ret.getResultInString(false, "")); } + @Test + public void testUDFGroupByAndOrderByExpr() { + String insert = + "INSERT INTO test(key, s1, s2) VALUES (1, 2, 3), (2, 3, 1), (3, 2, 3), (4, 3, 7), (5, 3, 6), (6, 0, 4);"; + tool.execute(insert); + + List cosTestS1AfterGroupByExpectedValues = + Arrays.asList(-0.9899924966004454, -0.4161468365471424, 1.0); + List sumTestS2AfterGroupByExpectedValues = Arrays.asList(14L, 6L, 4L); + + String query = "SELECT cos(s1), sum(s2) FROM test GROUP BY cos(s1) ORDER BY cos(s1);"; + SessionExecuteSqlResult ret = tool.execute(query); + compareResult(2, ret.getPaths().size()); + compareResult("cos(test.s1)", ret.getPaths().get(0)); + compareResult("sum(test.s2)", ret.getPaths().get(1)); + for (int i = 0; i < ret.getValues().size(); i++) { + compareResult(2, ret.getValues().get(i).size()); + double expectedCosS1 = cosTestS1AfterGroupByExpectedValues.get(i); + double actualCosS1 = (double) ret.getValues().get(i).get(0); + compareResult(expectedCosS1, actualCosS1, delta); + long expectedSumS2 = sumTestS2AfterGroupByExpectedValues.get(i); + long actualSumS2 = (long) ret.getValues().get(i).get(1); + assertEquals(expectedSumS2, actualSumS2); + } + + query = "SELECT cos(s1) AS a, sum(s2) FROM test GROUP BY a ORDER BY a;"; + ret = tool.execute(query); + compareResult(2, ret.getPaths().size()); + compareResult("a", ret.getPaths().get(0)); + compareResult("sum(test.s2)", ret.getPaths().get(1)); + for (int i = 0; i < ret.getValues().size(); i++) { + compareResult(2, ret.getValues().get(i).size()); + double expectedCosS1 = cosTestS1AfterGroupByExpectedValues.get(i); + double actualCosS1 = (double) ret.getValues().get(i).get(0); + compareResult(expectedCosS1, actualCosS1, delta); + long expectedSumS2 = sumTestS2AfterGroupByExpectedValues.get(i); + long actualSumS2 = (long) ret.getValues().get(i).get(1); + assertEquals(expectedSumS2, actualSumS2); + } + + query = "SELECT s1, s2 FROM test ORDER BY cos(s1);"; + ret = tool.execute(query); + String expected = + "ResultSets:\n" + + "+---+-------+-------+\n" + + "|key|test.s1|test.s2|\n" + + "+---+-------+-------+\n" + + "| 2| 3| 1|\n" + + "| 4| 3| 7|\n" + + "| 5| 3| 6|\n" + + "| 1| 2| 3|\n" + + "| 3| 2| 3|\n" + + "| 6| 0| 4|\n" + + "+---+-------+-------+\n" + + "Total line number = 6\n"; + compareResult(expected, ret.getResultInString(false, "")); + } + @Test public void testUDFWithArgs() { String insert = From 92071479a0e25221204f301f89c0ca901d0003b6 Mon Sep 17 00:00:00 2001 From: jzl18thu Date: Wed, 16 Oct 2024 08:12:12 +0800 Subject: [PATCH 2/6] debug --- .../engine/physical/memory/execute/utils/HeaderUtils.java | 4 +--- .../edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java | 2 ++ 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/utils/HeaderUtils.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/utils/HeaderUtils.java index dfd68ba422..4886c4616e 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/utils/HeaderUtils.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/physical/memory/execute/utils/HeaderUtils.java @@ -51,8 +51,6 @@ public class HeaderUtils { - private static final FunctionManager functionManager = FunctionManager.getInstance(); - public static Header constructNewHead(Header header, String markColumn) { List fields = new ArrayList<>(header.getFields()); fields.add(new Field(markColumn, BOOLEAN)); @@ -386,7 +384,7 @@ public static RowTransform checkSortHeader(Header header, Sort sort) { private static RowTransform appendArithExpressions(Header header, List expressions) { List functionCallList = new ArrayList<>(); - Function function = functionManager.getFunction(ARITHMETIC_EXPR); + Function function = FunctionManager.getInstance().getFunction(ARITHMETIC_EXPR); for (Field field : header.getFields()) { functionCallList.add( new FunctionCall(function, new FunctionParams(new BaseExpression(field.getName())))); diff --git a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java index 0b4516f75e..4b53f82d28 100644 --- a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java +++ b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java @@ -4665,6 +4665,7 @@ public void testSelectSubQuery() { statement = "SELECT a, (SELECT AVG(a) AS a1 FROM test.b GROUP BY d HAVING avg(test.b.a) > 2) * (SELECT AVG(a) AS a2 FROM test.b) FROM test.a;"; + LOGGER.debug(executor.execute("EXPLAIN " + statement)); expected = "ResultSets:\n" + "+---+--------+-------+\n" @@ -8294,6 +8295,7 @@ public void testDistinctEliminate() { statement = "SELECT max(distinct s1) FROM us.d1 GROUP BY s2;"; executor.execute(closeRule); assertTrue(executor.execute("EXPLAIN " + statement).contains("isDistinct: true")); + LOGGER.debug(executor.execute("EXPLAIN " + statement)); closeResult = executor.execute(statement); executor.execute(openRule); assertTrue(executor.execute("EXPLAIN " + statement).contains("isDistinct: false")); From bcfc195f13cc9bc725c31e1884357c6f3cc90e79 Mon Sep 17 00:00:00 2001 From: jzl18thu Date: Fri, 18 Oct 2024 09:58:53 +0800 Subject: [PATCH 3/6] fix --- .../main/java/cn/edu/tsinghua/iginx/sql/IginXSqlVisitor.java | 2 +- .../edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/sql/IginXSqlVisitor.java b/core/src/main/java/cn/edu/tsinghua/iginx/sql/IginXSqlVisitor.java index dcbaf7e5a3..06a9739256 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/sql/IginXSqlVisitor.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/sql/IginXSqlVisitor.java @@ -1369,7 +1369,7 @@ private void parseGroupByClause(GroupByClauseContext ctx, UnarySelectStatement s // GROUP BY的表达式没有出现在SELECT子句中 if (groupByExprSet.isEmpty()) { selectStatement.setGroupByExpr(expr); - String originPath = selectStatement.getOriginPath(path); + String originPath = selectStatement.getOriginPath(baseExpr.getPathName()); if (originPath != null) { selectStatement.addGroupByPath(originPath); } diff --git a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java index 4b53f82d28..0b4516f75e 100644 --- a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java +++ b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java @@ -4665,7 +4665,6 @@ public void testSelectSubQuery() { statement = "SELECT a, (SELECT AVG(a) AS a1 FROM test.b GROUP BY d HAVING avg(test.b.a) > 2) * (SELECT AVG(a) AS a2 FROM test.b) FROM test.a;"; - LOGGER.debug(executor.execute("EXPLAIN " + statement)); expected = "ResultSets:\n" + "+---+--------+-------+\n" @@ -8295,7 +8294,6 @@ public void testDistinctEliminate() { statement = "SELECT max(distinct s1) FROM us.d1 GROUP BY s2;"; executor.execute(closeRule); assertTrue(executor.execute("EXPLAIN " + statement).contains("isDistinct: true")); - LOGGER.debug(executor.execute("EXPLAIN " + statement)); closeResult = executor.execute(statement); executor.execute(openRule); assertTrue(executor.execute("EXPLAIN " + statement).contains("isDistinct: false")); From 4da6f20fadae5dd2f71b281f032ee28fa89518c2 Mon Sep 17 00:00:00 2001 From: jzl18thu Date: Sat, 19 Oct 2024 12:01:34 +0800 Subject: [PATCH 4/6] adjust parser --- .../antlr4/cn/edu/tsinghua/iginx/sql/Sql.g4 | 9 +- .../tsinghua/iginx/sql/IginXSqlVisitor.java | 213 +++++++++--------- 2 files changed, 117 insertions(+), 105 deletions(-) diff --git a/antlr/src/main/antlr4/cn/edu/tsinghua/iginx/sql/Sql.g4 b/antlr/src/main/antlr4/cn/edu/tsinghua/iginx/sql/Sql.g4 index 74ad2f79e0..d2a4f59b95 100644 --- a/antlr/src/main/antlr4/cn/edu/tsinghua/iginx/sql/Sql.g4 +++ b/antlr/src/main/antlr4/cn/edu/tsinghua/iginx/sql/Sql.g4 @@ -301,7 +301,12 @@ specialClause ; groupByClause - : GROUP BY expression (COMMA expression)* + : GROUP BY groupByItem (COMMA groupByItem)* + ; + +groupByItem + : path + | expression ; havingClause @@ -313,7 +318,7 @@ orderByClause ; orderItem - : expression (DESC | ASC)? + : (path | expression) (DESC | ASC)? ; downsampleClause diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/sql/IginXSqlVisitor.java b/core/src/main/java/cn/edu/tsinghua/iginx/sql/IginXSqlVisitor.java index 06a9739256..40a6c12a2c 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/sql/IginXSqlVisitor.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/sql/IginXSqlVisitor.java @@ -89,6 +89,7 @@ import cn.edu.tsinghua.iginx.sql.SqlParser.FromClauseContext; import cn.edu.tsinghua.iginx.sql.SqlParser.FunctionContext; import cn.edu.tsinghua.iginx.sql.SqlParser.GroupByClauseContext; +import cn.edu.tsinghua.iginx.sql.SqlParser.GroupByItemContext; import cn.edu.tsinghua.iginx.sql.SqlParser.ImportFileClauseContext; import cn.edu.tsinghua.iginx.sql.SqlParser.InsertFromFileStatementContext; import cn.edu.tsinghua.iginx.sql.SqlParser.InsertFullPathSpecContext; @@ -1331,71 +1332,12 @@ private void parseDownsampleClause( private void parseGroupByClause(GroupByClauseContext ctx, UnarySelectStatement selectStatement) { selectStatement.setHasGroupBy(true); - for (ExpressionContext exprCtx : ctx.expression()) { - if (exprCtx.subquery() != null) { - throw new SQLParserException("Subquery is not supported in GROUP BY columns."); - } - - Expression expr = parseExpression(exprCtx, selectStatement, false).get(0); - - if (expr instanceof BaseExpression) { - BaseExpression baseExpr = (BaseExpression) expr; - Set groupByExprSet = new HashSet<>(); - String path = baseExpr.getPathName(); - if (path.contains("*")) { - throw new SQLParserException( - String.format("GROUP BY column '%s' has '*', which is not supported.", path)); - } - // 删去在解析expression时加上的前缀 - if (selectStatement.isFromSinglePath()) { - path = path.replaceFirst(selectStatement.getFromPart(0).getPrefix() + "\\.", ""); - } - - for (Expression selectExpr : selectStatement.getExpressions()) { - if (selectExpr.equalExceptAlias(baseExpr)) { // 匹配select表达式 - groupByExprSet.add(selectExpr); - continue; - } - if (selectExpr.getAlias().equals(path)) { // 匹配select表达式别名 - groupByExprSet.add(selectExpr); - } - } - - // 匹配到了多个select表达式 - if (groupByExprSet.size() > 1) { - throw new SQLParserException(String.format("GROUP BY column '%s' is ambiguous.", path)); - } - - // GROUP BY的表达式没有出现在SELECT子句中 - if (groupByExprSet.isEmpty()) { - selectStatement.setGroupByExpr(expr); - String originPath = selectStatement.getOriginPath(baseExpr.getPathName()); - if (originPath != null) { - selectStatement.addGroupByPath(originPath); - } - } else { - selectStatement.setGroupByExpr(groupByExprSet.iterator().next()); - } - } else { - MappingType type = ExpressionUtils.getExprMappingType(expr); - if (type == MappingType.SetMapping || type == MappingType.Mapping) { - throw new SQLParserException("GROUP BY column can not use SetToSet/SetToRow functions."); - } - selectStatement.setGroupByExpr(expr); - - List baseExpressions = - ExpressionUtils.getBaseExpressionList(Collections.singletonList(expr), false); - baseExpressions.forEach( - baseExpression -> { - String path = baseExpression.getPathName(); - String originPath = selectStatement.getOriginPath(path); - if (originPath != null) { - selectStatement.addGroupByPath(originPath); - } - }); - } + for (GroupByItemContext groupByItemContext : ctx.groupByItem()) { + Expression groupByExpr = parseGroupByItem(groupByItemContext, selectStatement); + selectStatement.setGroupByExpr(groupByExpr); } + // 检查SELECT子句中未被聚合的列是否在GROUP BY子句中出现 selectStatement .getExpressions() .forEach( @@ -1420,6 +1362,73 @@ private void parseGroupByClause(GroupByClauseContext ctx, UnarySelectStatement s }); } + private Expression parseGroupByItem( + GroupByItemContext ctx, UnarySelectStatement selectStatement) { + if (ctx.path() != null) { + String path = parsePath(ctx.path()); + if (path.contains("*")) { + throw new SQLParserException( + String.format("GROUP BY column '%s' has '*', which is not supported.", path)); + } + + String fullPath = path; + // 如果查询语句的FROM子句只有一个部分且FROM一个前缀,则GROUP BY后的path只用写出后缀 + if (selectStatement.isFromSinglePath()) { + fullPath = selectStatement.getFromPart(0).getPrefix() + SQLConstant.DOT + path; + } + + Set groupByExprSet = new HashSet<>(); + for (Expression selectExpr : selectStatement.getExpressions()) { + if (selectExpr.getColumnName().equals(fullPath)) { // 直接匹配select表达式 + groupByExprSet.add(selectExpr); + continue; + } + if (selectExpr.getAlias().equals(path)) { // 根据别名匹配select表达式 + groupByExprSet.add(selectExpr); + } + } + + // 匹配到了多个select表达式 + if (groupByExprSet.size() > 1) { + throw new SQLParserException(String.format("GROUP BY column '%s' is ambiguous.", path)); + } + + String originPath = selectStatement.getOriginPath(fullPath); + if (originPath != null) { + selectStatement.addGroupByPath(originPath); + } + + // GROUP BY的列没有出现在SELECT子句中 + if (groupByExprSet.isEmpty()) { + return new BaseExpression(fullPath); + } else { + return groupByExprSet.iterator().next(); + } + } else { + assert ctx.expression() != null; + if (ctx.expression().subquery() != null) { + throw new SQLParserException("Subquery is not supported in GROUP BY columns."); + } + Expression expr = parseExpression(ctx.expression(), selectStatement, false).get(0); + MappingType type = ExpressionUtils.getExprMappingType(expr); + if (type == MappingType.SetMapping || type == MappingType.Mapping) { + throw new SQLParserException("GROUP BY column can not use SetToSet/SetToRow functions."); + } + + // 查找需要加入到pathSet的path + List baseExpressions = + ExpressionUtils.getBaseExpressionList(Collections.singletonList(expr), false); + baseExpressions.forEach( + baseExpression -> { + String originPath = selectStatement.getOriginPath(baseExpression.getPathName()); + if (originPath != null) { + selectStatement.addGroupByPath(originPath); + } + }); + return expr; + } + } + // like standard SQL, limit N, M means limit M offset N private void parseLimitClause(LimitClauseContext ctx, SelectStatement selectStatement) { Pair p = getLimitAndOffsetFromCtx(ctx); @@ -1465,61 +1474,60 @@ private void parseOrderByClause(OrderByClauseContext ctx, SelectStatement select } private void parseOrderItem(OrderItemContext ctx, SelectStatement selectStatement) { - if (ctx.expression().subquery() != null) { - throw new SQLParserException("Subquery is not supported in ORDER BY columns."); - } + UnarySelectStatement statement = selectStatement.getFirstUnarySelectStatement(); + if (ctx.path() != null) { + String path = parsePath(ctx.path()); - UnarySelectStatement unarySelectStatement = selectStatement.getFirstUnarySelectStatement(); - Expression expr = parseExpression(ctx.expression(), unarySelectStatement, false).get(0); - if (expr instanceof BaseExpression) { - BaseExpression baseExpr = (BaseExpression) expr; Set orderByExprSet = new HashSet<>(); - String path = baseExpr.getPathName(); if (path.contains("*")) { throw new SQLParserException( String.format("ORDER BY column '%s' has '*', which is not supported.", path)); } - if (selectStatement.getSelectType() == SelectStatement.SelectStatementType.UNARY) { - UnarySelectStatement stmt = (UnarySelectStatement) selectStatement; - String pathRemovePrefix = path; - // 删去在解析expression时加上的前缀 - if (stmt.isFromSinglePath()) { - pathRemovePrefix = path.replaceFirst(stmt.getFromPart(0).getPrefix() + "\\.", ""); - } - for (Expression selectExpr : selectStatement.getExpressions()) { - if (StringUtils.match(path, selectExpr.getColumnName())) { // 匹配select表达式 - orderByExprSet.add(expr); - continue; - } - if (selectExpr.getAlias().equals(pathRemovePrefix)) { // 匹配select表达式别名 - orderByExprSet.add(selectExpr); - } - } + String fullPath = path; + // 如果查询语句的FROM子句只有一个部分且FROM一个前缀,则ORDER BY后的path只用写出后缀 + if (statement.isFromSinglePath()) { + fullPath = statement.getFromPart(0).getPrefix() + SQLConstant.DOT + path; + } - // 匹配到了多个select表达式 - if (orderByExprSet.size() > 1) { - throw new SQLParserException(String.format("ORDER BY column '%s' is ambiguous.", path)); + for (Expression selectExpr : selectStatement.getExpressions()) { + if (StringUtils.match(fullPath, selectExpr.getColumnName())) { // 直接匹配select表达式 + orderByExprSet.add(new BaseExpression(fullPath)); + continue; } - - // ORDER BY的表达式没有出现在SELECT子句中 - if (orderByExprSet.isEmpty()) { - selectStatement.setOrderByExpr(expr); - String originPath = selectStatement.getOriginPath(path); - if (originPath != null) { - ((UnarySelectStatement) selectStatement).addOrderByPath(originPath); - } - } else { - selectStatement.setOrderByExpr(orderByExprSet.iterator().next()); + if (selectExpr.getAlias().equals(path)) { // 根据别名匹配select表达式 + orderByExprSet.add(selectExpr); } } + + // 匹配到了多个select表达式 + if (orderByExprSet.size() > 1) { + throw new SQLParserException(String.format("ORDER BY column '%s' is ambiguous.", path)); + } + + String originPath = selectStatement.getOriginPath(path); + if (originPath != null) { + ((UnarySelectStatement) selectStatement).addOrderByPath(originPath); + } + + // ORDER BY的表达式没有出现在SELECT子句中 + if (orderByExprSet.isEmpty()) { + selectStatement.setOrderByExpr(new BaseExpression(fullPath)); + } else { + selectStatement.setOrderByExpr(orderByExprSet.iterator().next()); + } } else { + assert ctx.expression() != null; + if (ctx.expression().subquery() != null) { + throw new SQLParserException("Subquery is not supported in ORDER BY columns."); + } + Expression expr = parseExpression(ctx.expression(), statement, false).get(0); MappingType type = ExpressionUtils.getExprMappingType(expr); if (type == MappingType.SetMapping || type == MappingType.Mapping) { throw new SQLParserException("ORDER BY column can not use SetToSet/SetToRow functions."); } - // 在SELECT子句中查找相同的表达式,避免重复计算(主要是case when) + // 在SELECT子句中查找相同的表达式,替换ORDER BY中的表达式,避免重复计算(主要是case when) boolean foundInSelect = false; for (Expression selectExpr : selectStatement.getExpressions()) { if (ExpressionUtils.getExprMappingType(selectExpr) == MappingType.RowMapping @@ -1538,8 +1546,7 @@ private void parseOrderItem(OrderItemContext ctx, SelectStatement selectStatemen ExpressionUtils.getBaseExpressionList(Collections.singletonList(expr), false); baseExpressions.forEach( baseExpression -> { - String path = baseExpression.getPathName(); - String originPath = selectStatement.getOriginPath(path); + String originPath = selectStatement.getOriginPath(baseExpression.getPathName()); if (originPath != null) { ((UnarySelectStatement) selectStatement).addOrderByPath(originPath); } From a8aeeb63b0bca49d661d8e115e1e43646feb3506 Mon Sep 17 00:00:00 2001 From: jzl18thu Date: Tue, 22 Oct 2024 11:27:40 +0800 Subject: [PATCH 5/6] add tests --- .../integration/func/sql/SQLSessionIT.java | 11 +++++++++++ .../iginx/integration/func/udf/UDFIT.java | 17 +++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java index 1ac46fd747..50e38dbdf2 100644 --- a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java +++ b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java @@ -6390,6 +6390,17 @@ public void testErrorClause() { errClause = "select s1 as key, s2 as key from us.d1;"; executor.executeAndCompareErrMsg( errClause, "Only one 'AS KEY' can be used in each select at most."); + + errClause = "select s1, s2 AS s1, count(s3) from us.d1 group by s1, s2;"; + executor.executeAndCompareErrMsg(errClause, "GROUP BY column 's1' is ambiguous."); + + errClause = "select s1, s2, count(s3) from us.d1 group by max(s1);"; + executor.executeAndCompareErrMsg( + errClause, "GROUP BY column can not use SetToSet/SetToRow functions."); + + errClause = "select s1, s2, count(s3) from us.d1 group by s1, s2 order by first(s1);"; + executor.executeAndCompareErrMsg( + errClause, "ORDER BY column can not use SetToSet/SetToRow functions."); } @Test diff --git a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/udf/UDFIT.java b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/udf/UDFIT.java index 5e1ba6a02f..b912fa82e8 100644 --- a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/udf/UDFIT.java +++ b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/udf/UDFIT.java @@ -735,6 +735,23 @@ public void testUDFGroupByAndOrderByExpr() { + "+---+-------+-------+\n" + "Total line number = 6\n"; compareResult(expected, ret.getResultInString(false, "")); + + query = "SELECT s1, s2 FROM test ORDER BY pow(s2, 2);"; + ret = tool.execute(query); + expected = + "ResultSets:\n" + + "+---+-------+-------+\n" + + "|key|test.s1|test.s2|\n" + + "+---+-------+-------+\n" + + "| 2| 3| 1|\n" + + "| 1| 2| 3|\n" + + "| 3| 2| 3|\n" + + "| 6| 0| 4|\n" + + "| 5| 3| 6|\n" + + "| 4| 3| 7|\n" + + "+---+-------+-------+\n" + + "Total line number = 6\n"; + compareResult(expected, ret.getResultInString(false, "")); } @Test From f1cecebad8e6f51384f0213d188e3d0938b56b66 Mon Sep 17 00:00:00 2001 From: jzl18thu Date: Tue, 22 Oct 2024 11:36:11 +0800 Subject: [PATCH 6/6] add test --- .../tsinghua/iginx/integration/func/udf/UDFIT.java | 11 +++++++++++ .../iginx/integration/func/udf/UDFTestTools.java | 11 +++++++++++ 2 files changed, 22 insertions(+) diff --git a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/udf/UDFIT.java b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/udf/UDFIT.java index b912fa82e8..7653a86d48 100644 --- a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/udf/UDFIT.java +++ b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/udf/UDFIT.java @@ -894,6 +894,17 @@ public void testUDFWithArgsAndKvArgs() { compareResult(expected, ret.getResultInString(false, "")); } + @Test + public void testErrorClause() { + String errClause = "select s1, s2, count(s3) from us.d1 group by reverse_rows(s1);"; + tool.executeAndCompareErrMsg( + errClause, "GROUP BY column can not use SetToSet/SetToRow functions."); + + errClause = "select s1, s2, count(s3) from us.d1 group by s1, s2 order by transpose(s1);"; + tool.executeAndCompareErrMsg( + errClause, "ORDER BY column can not use SetToSet/SetToRow functions."); + } + void compareResult(Object expected, Object actual) { if (!needCompareResult) { return; diff --git a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/udf/UDFTestTools.java b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/udf/UDFTestTools.java index 72d3006eef..fb8fbc5bd0 100644 --- a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/udf/UDFTestTools.java +++ b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/udf/UDFTestTools.java @@ -141,6 +141,17 @@ void executeFail(String statement) { fail("Statement: \"{}\" execute without failure, which was not expected."); } + public void executeAndCompareErrMsg(String statement, String expectedErrMsg) { + LOGGER.info("Execute Statement: \"{}\"", statement); + + try { + session.executeSql(statement); + } catch (SessionException e) { + LOGGER.info("Statement: \"{}\" execute fail. Because: ", statement, e); + assertEquals(expectedErrMsg, e.getMessage()); + } + } + boolean isUDFRegistered(String udfName) { SessionExecuteSqlResult ret = execute(SHOW_FUNCTION_SQL); List registerUDFs =