Skip to content

Commit

Permalink
[GLUTEN-1490] Refactor Substrait literals using generics, and support…
Browse files Browse the repository at this point in the history
… map/struct/array literals (facebookincubator#1494)

Refactor substrait literals using generics, and support complex type
literals based on it.
  • Loading branch information
taiyang-li authored May 18, 2023
1 parent ffd4114 commit 82e7c41
Show file tree
Hide file tree
Showing 47 changed files with 656 additions and 679 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,31 @@ class GlutenClickHouseTPCHParquetSuite extends GlutenClickHouseTPCHAbstractSuite
"from lineitem order by l_shipdate limit 10;")(checkOperatorMatch[ProjectExecTransformer])
}

test("test literals") {
val query = """
SELECT
CAST(NULL AS BOOLEAN) AS boolean_literal,
CAST(1 AS TINYINT) AS tinyint_literal,
CAST(2 AS SMALLINT) AS smallint_literal,
CAST(3 AS INTEGER) AS integer_literal,
CAST(4 AS BIGINT) AS bigint_literal,
CAST(5.5 AS FLOAT) AS float_literal,
CAST(6.6 AS DOUBLE) AS double_literal,
CAST('7' AS STRING) AS string_literal,
DATE '2022-01-01' AS date_literal,
TIMESTAMP '2022-01-01 10:00:00' AS timestamp_literal,
CAST(X'48656C6C6F' AS BINARY) AS binary_literal,
ARRAY(1, 2, 3, 4) AS array_literal,
MAP("a", 1, "b", 2) AS map_literal,
STRUCT("hello", 123) AS struct_literal,
ARRAY() as empty_array_literal,
MAP() as empty_map_literal,
ARRAY(1, NULL, 3) as array_with_null_literal,
MAP(1, 2, CAST(3 as SHORT), null) as map_with_null_literal
from range(10)"""
runQueryAndCompare(query)(checkOperatorMatch[ProjectExecTransformer])
}

// see issue https://github.com/Kyligence/ClickHouse/issues/93
ignore("TPCH Q22") {
runTPCHQuery(22) { df => }
Expand Down
130 changes: 96 additions & 34 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,40 @@
#include <Builder/BroadCastJoinBuilder.h>
#include <Columns/ColumnSet.h>
#include <Core/Block.h>
#include <Core/ColumnWithTypeAndName.h>
#include <Core/Names.h>
#include <Core/NamesAndTypes.h>
#include <Core/Types.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeDate32.h>
#include <DataTypes/DataTypeDateTime64.h>
#include <DataTypes/DataTypeFactory.h>
#include <DataTypes/DataTypeFixedString.h>
#include <DataTypes/DataTypeMap.h>
#include <DataTypes/DataTypeNothing.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeSet.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/IDataType.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <DataTypes/getLeastSupertype.h>
#include <Functions/CastOverloadResolver.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionsConversion.h>
#include <Functions/registerFunctions.h>
#include <Interpreters/ActionsDAG.h>
#include <Interpreters/ActionsVisitor.h>
#include <Interpreters/CollectJoinOnKeysVisitor.h>
#include <Interpreters/Context.h>
#include <Interpreters/HashJoin.h>
#include <Operator/PartitionColumnFillingTransform.h>
#include <Parser/RelParser.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ExpressionListParsers.h>
#include <Processors/Executors/PullingAsyncPipelineExecutor.h>
#include <Processors/Formats/Impl/ArrowBlockOutputFormat.h>
#include <Processors/Formats/Impl/ParquetBlockInputFormat.h>
Expand All @@ -39,19 +50,25 @@
#include <Processors/QueryPlan/LimitStep.h>
#include <Processors/QueryPlan/MergingAggregatedStep.h>
#include <Processors/QueryPlan/Optimizations/QueryPlanOptimizationSettings.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <Processors/QueryPlan/ReadFromPreparedSource.h>
#include <Processors/QueryPlan/SortingStep.h>
#include <Processors/Transforms/AggregatingTransform.h>
#include <QueryPipeline/Pipe.h>
#include <QueryPipeline/QueryPipelineBuilder.h>
#include <Storages/CustomStorageMergeTree.h>
#include <Storages/IStorage.h>
#include <Storages/MergeTree/MergeTreeData.h>
#include <Storages/StorageMergeTreeFactory.h>
#include <Storages/SubstraitSource/SubstraitFileSource.h>
#include <base/Decimal.h>
#include <base/types.h>
#include <google/protobuf/util/json_util.h>
#include <google/protobuf/wrappers.pb.h>
#include <sys/select.h>
#include <Poco/StringTokenizer.h>
#include <Poco/Util/MapConfiguration.h>
#include <Common/CHUtil.h>
#include <Common/DebugUtils.h>
#include <Common/Exception.h>
#include <Common/JoinHelper.h>
Expand All @@ -60,24 +77,6 @@
#include <Common/logger_useful.h>
#include <Common/typeid_cast.h>

#include <Core/ColumnWithTypeAndName.h>
#include <Core/Types.h>
#include <DataTypes/DataTypeFactory.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/IDataType.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <Functions/CastOverloadResolver.h>
#include <Functions/FunctionsConversion.h>
#include <Parser/RelParser.h>
#include <Parsers/ExpressionListParsers.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <Processors/QueryPlan/SortingStep.h>
#include <Storages/IStorage.h>
#include <base/types.h>
#include <sys/select.h>
#include <Common/CHUtil.h>
#include "SerializedPlanParser.h"

namespace DB
{
namespace ErrorCodes
Expand Down Expand Up @@ -1908,33 +1907,96 @@ std::pair<DataTypePtr, Field> SerializedPlanParser::parseLiteral(const substrait
throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support decimal type with precision {}", precision);
break;
}
/// TODO(taiyang-li) Other type: Struct/Map/List
case substrait::Expression_Literal::kList: {
/// TODO(taiyang-li) Implement empty list
if (literal.has_empty_list())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Empty list not support!");

DataTypePtr first_type;
std::tie(first_type, std::ignore) = parseLiteral(literal.list().values(0));
const auto & values = literal.list().values();
if (values.empty())
{
type = std::make_shared<DataTypeArray>(std::make_shared<DataTypeNothing>());
field = Array();
break;
}

size_t list_len = literal.list().values_size();
DataTypePtr common_type;
std::tie(common_type, std::ignore) = parseLiteral(values[0]);
size_t list_len = values.size();
Array array(list_len);
for (size_t i = 0; i < list_len; ++i)
{
auto type_and_field = std::move(parseLiteral(literal.list().values(i)));
if (!first_type->equals(*type_and_field.first))
throw Exception(
ErrorCodes::LOGICAL_ERROR,
"Literal list type mismatch:{} and {}",
first_type->getName(),
type_and_field.first->getName());
auto type_and_field = parseLiteral(values[i]);
common_type = getLeastSupertype(DataTypes{common_type, type_and_field.first});
array[i] = std::move(type_and_field.second);
}

type = std::make_shared<DataTypeArray>(first_type);
type = std::make_shared<DataTypeArray>(common_type);
field = std::move(array);
break;
}
case substrait::Expression_Literal::kMap: {
const auto & key_values = literal.map().key_values();
if (key_values.empty())
{
type = std::make_shared<DataTypeMap>(std::make_shared<DataTypeNothing>(), std::make_shared<DataTypeNothing>());
field = Map();
break;
}

const auto & first_key_value = key_values[0];

DataTypePtr common_key_type;
std::tie(common_key_type, std::ignore) = parseLiteral(first_key_value.key());

DataTypePtr common_value_type;
std::tie(common_value_type, std::ignore) = parseLiteral(first_key_value.value());

Map map;
map.reserve(key_values.size());
for (int i = 0; i < key_values.size(); ++i)
{
Tuple tuple(2);

DataTypePtr key_type;
std::tie(key_type, tuple[0]) = parseLiteral(key_values[i].key());
/// Each key should has the same type
if (!common_key_type->equals(*key_type))
throw Exception(
ErrorCodes::LOGICAL_ERROR,
"Literal map key type mismatch:{} and {}",
common_key_type->getName(),
key_type->getName());

DataTypePtr value_type;
std::tie(value_type, tuple[1]) = parseLiteral(key_values[i].value());
/// Each value should has least super type for all of them
common_value_type = getLeastSupertype(DataTypes{common_value_type, value_type});

map.emplace_back(std::move(tuple));
}

type = std::make_shared<DataTypeMap>(common_key_type, common_value_type);
field = std::move(map);
break;
}
case substrait::Expression_Literal::kStruct: {
const auto & fields = literal.struct_().fields();

DataTypes types;
types.reserve(fields.size());
Tuple tuple;
tuple.reserve(fields.size());
for (const auto & f : fields)
{
DataTypePtr field_type;
Field field_value;
std::tie(field_type, field_value) = parseLiteral(f);

types.emplace_back(std::move(field_type));
tuple.emplace_back(std::move(field_value));
}

type = std::make_shared<DataTypeTuple>(types);
field = std::move(tuple);
break;
}
case substrait::Expression_Literal::kNull: {
type = parseType(literal.null());
field = std::move(Field{});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,23 @@

package io.glutenproject.substrait.expression;

import io.substrait.proto.Expression;

import java.io.Serializable;
import io.glutenproject.substrait.type.BinaryTypeNode;
import io.glutenproject.substrait.type.TypeNode;
import io.substrait.proto.Expression.Literal.Builder;

import com.google.protobuf.ByteString;

public class BinaryLiteralNode implements ExpressionNode, Serializable {
private final ByteString value;

public class BinaryLiteralNode extends LiteralNodeWithValue<byte[]> {
public BinaryLiteralNode(byte[] value) {
this.value = ByteString.copyFrom(value);
super(value, new BinaryTypeNode(true));
}

@Override
public Expression toProtobuf() {
Expression.Literal.Builder binaryBuilder =
Expression.Literal.newBuilder();
binaryBuilder.setBinary(value);
public BinaryLiteralNode(byte[] value, TypeNode typeNode) {
super(value, typeNode);
}

Expression.Builder builder = Expression.newBuilder();
builder.setLiteral(binaryBuilder.build());
return builder.build();
@Override
protected void updateLiteralBuilder(Builder literalBuilder, byte[] value) {
literalBuilder.setBinary(ByteString.copyFrom(value));
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,22 @@

package io.glutenproject.substrait.expression;

import io.substrait.proto.Expression;
import io.glutenproject.substrait.type.BooleanTypeNode;
import io.glutenproject.substrait.type.TypeNode;

import java.io.Serializable;

public class BooleanLiteralNode implements ExpressionNode, Serializable {
private final Boolean value;
import io.substrait.proto.Expression.Literal.Builder;

public class BooleanLiteralNode extends LiteralNodeWithValue<Boolean> {
public BooleanLiteralNode(Boolean value) {
this.value = value;
super(value, new BooleanTypeNode(true));
}

@Override
public Expression toProtobuf() {
Expression.Literal.Builder booleanBuilder =
Expression.Literal.newBuilder();
booleanBuilder.setBoolean(value);
public BooleanLiteralNode(Boolean value, TypeNode typeNode) {
super(value, typeNode);
}

Expression.Builder builder = Expression.newBuilder();
builder.setLiteral(booleanBuilder.build());
return builder.build();
@Override
protected void updateLiteralBuilder(Builder literalBuilder, Boolean value) {
literalBuilder.setBoolean(value);
}
}
Loading

0 comments on commit 82e7c41

Please sign in to comment.