From ce6dc8cf6058500e9461aedda0e75176142cbc56 Mon Sep 17 00:00:00 2001 From: Sacha Viscaino Date: Mon, 31 Oct 2022 10:20:54 +0000 Subject: [PATCH] Add TypeParameters to SqlInvokedFunction --- ...SqlInvokedScalarFromAnnotationsParser.java | 8 ++ ...tAnnotationEngineForSqlInvokedScalars.java | 114 ++++++++++++++++++ .../spi/function/SqlInvokedFunction.java | 17 ++- 3 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForSqlInvokedScalars.java diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/SqlInvokedScalarFromAnnotationsParser.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/SqlInvokedScalarFromAnnotationsParser.java index 1dc7d107489ed..701394ecefcfe 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/SqlInvokedScalarFromAnnotationsParser.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/SqlInvokedScalarFromAnnotationsParser.java @@ -27,6 +27,8 @@ import com.facebook.presto.spi.function.SqlParameter; import com.facebook.presto.spi.function.SqlParameters; import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import com.facebook.presto.spi.function.TypeVariableConstraint; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -45,6 +47,7 @@ import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.NOT_DETERMINISTIC; import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.CALLED_ON_NULL_INPUT; import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT; +import static com.facebook.presto.spi.function.Signature.withVariadicBound; import static com.facebook.presto.util.Failures.checkCondition; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -156,10 +159,15 @@ else if (method.isAnnotationPresent(SqlParameters.class)) { throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, format("Failed to get function body for method [%s]", method), e); } + List typeVariableConstraints = stream(method.getAnnotationsByType(TypeParameter.class)) + .map(t -> withVariadicBound(t.value(), t.boundedBy().isEmpty() ? null : t.boundedBy())) + .collect(toImmutableList()); + return Stream.concat(Stream.of(functionHeader.value()), stream(functionHeader.alias())) .map(name -> new SqlInvokedFunction( QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, name), parameters, + typeVariableConstraints, returnType, functionDescription, routineCharacteristics, diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForSqlInvokedScalars.java b/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForSqlInvokedScalars.java new file mode 100644 index 0000000000000..0a545d8e58dd4 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForSqlInvokedScalars.java @@ -0,0 +1,114 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.TypeSignature; +import com.facebook.presto.operator.scalar.annotations.SqlInvokedScalarFromAnnotationsParser; +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.FunctionKind; +import com.facebook.presto.spi.function.Signature; +import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.function.SqlInvokedScalarFunction; +import com.facebook.presto.spi.function.SqlParameter; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import com.facebook.presto.spi.function.TypeVariableConstraint; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import java.util.Collections; +import java.util.List; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.DEFAULT_NAMESPACE; +import static com.facebook.presto.spi.function.SqlFunctionVisibility.PUBLIC; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestAnnotationEngineForSqlInvokedScalars + extends TestAnnotationEngine +{ + @Test + public void testParseFunctionDefinition() + { + Signature expectedSignature = new Signature( + QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "sample_sql_invoked_scalar_function"), + FunctionKind.SCALAR, + new ArrayType(BIGINT).getTypeSignature(), + ImmutableList.of(INTEGER.getTypeSignature())); + + List functions = SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(SingleImplementationSQLInvokedScalarFunction.class); + assertEquals(functions.size(), 1); + SqlInvokedFunction f = functions.get(0); + + assertEquals(f.getSignature(), expectedSignature); + assertTrue(f.isDeterministic()); + assertEquals(f.getVisibility(), PUBLIC); + assertEquals(f.getDescription(), "Simple SQL invoked scalar function"); + + assertEquals(f.getBody(), "RETURN SEQUENCE(1, input)"); + } + + @Test + public void testParseFunctionDefinitionWithTypeParameter() + { + Signature expectedSignature = new Signature( + QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "sample_sql_invoked_scalar_function_with_type_parameter"), + FunctionKind.SCALAR, + ImmutableList.of(new TypeVariableConstraint("T", false, false, null, false)), + Collections.emptyList(), + TypeSignature.parseTypeSignature("array(T)"), + ImmutableList.of(new TypeSignature("T")), + false); + + List functions = SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(SingleImplementationSQLInvokedScalarFunctionWithTypeParameter.class); + assertEquals(functions.size(), 1); + SqlInvokedFunction f = functions.get(0); + + assertEquals(f.getSignature(), expectedSignature); + assertTrue(f.isDeterministic()); + assertEquals(f.getVisibility(), PUBLIC); + assertEquals(f.getDescription(), "Simple SQL invoked scalar function with type parameter"); + + assertEquals(f.getBody(), "RETURN ARRAY[input]"); + } + + public static class SingleImplementationSQLInvokedScalarFunction + { + @SqlInvokedScalarFunction(value = "sample_sql_invoked_scalar_function", deterministic = true, calledOnNullInput = false) + @Description("Simple SQL invoked scalar function") + @SqlParameter(name = "input", type = "integer") + @SqlType("array") + public static String fun() + { + return "RETURN SEQUENCE(1, input)"; + } + } + + public static class SingleImplementationSQLInvokedScalarFunctionWithTypeParameter + { + @SqlInvokedScalarFunction(value = "sample_sql_invoked_scalar_function_with_type_parameter", deterministic = true, calledOnNullInput = false) + @Description("Simple SQL invoked scalar function with type parameter") + @TypeParameter("T") + @SqlParameter(name = "input", type = "T") + @SqlType("array") + public static String fun() + { + return "RETURN ARRAY[input]"; + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlInvokedFunction.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlInvokedFunction.java index 8709949dfcc2f..06bead006e5f2 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlInvokedFunction.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlInvokedFunction.java @@ -28,6 +28,7 @@ import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; import static com.facebook.presto.spi.function.SqlFunctionVisibility.PUBLIC; import static java.lang.String.format; +import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.collectingAndThen; import static java.util.stream.Collectors.joining; @@ -74,6 +75,19 @@ public SqlInvokedFunction( RoutineCharacteristics routineCharacteristics, String body, FunctionVersion version) + { + this(functionName, parameters, emptyList(), returnType, description, routineCharacteristics, body, version); + } + + public SqlInvokedFunction( + QualifiedObjectName functionName, + List parameters, + List typeVariableConstraints, + TypeSignature returnType, + String description, + RoutineCharacteristics routineCharacteristics, + String body, + FunctionVersion version) { this.parameters = requireNonNull(parameters, "parameters is null"); this.description = requireNonNull(description, "description is null"); @@ -83,7 +97,8 @@ public SqlInvokedFunction( List argumentTypes = parameters.stream() .map(Parameter::getType) .collect(collectingAndThen(toList(), Collections::unmodifiableList)); - this.signature = new Signature(functionName, SCALAR, returnType, argumentTypes); + + this.signature = new Signature(functionName, SCALAR, typeVariableConstraints, emptyList(), returnType, argumentTypes, false); this.functionId = new SqlFunctionId(functionName, argumentTypes); this.functionVersion = requireNonNull(version, "version is null"); this.functionHandle = version.hasVersion() ? Optional.of(new SqlFunctionHandle(this.functionId, version.toString())) : Optional.empty();