Skip to content

Commit

Permalink
Add TypeParameters to SqlInvokedFunction
Browse files Browse the repository at this point in the history
  • Loading branch information
Sacha Viscaino committed Oct 31, 2022
1 parent e97afc8 commit ce6dc8c
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import com.facebook.presto.spi.function.SqlParameter;
import com.facebook.presto.spi.function.SqlParameters;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;
import com.facebook.presto.spi.function.TypeVariableConstraint;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

Expand All @@ -45,6 +47,7 @@
import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.NOT_DETERMINISTIC;
import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.CALLED_ON_NULL_INPUT;
import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT;
import static com.facebook.presto.spi.function.Signature.withVariadicBound;
import static com.facebook.presto.util.Failures.checkCondition;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
Expand Down Expand Up @@ -156,10 +159,15 @@ else if (method.isAnnotationPresent(SqlParameters.class)) {
throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, format("Failed to get function body for method [%s]", method), e);
}

List<TypeVariableConstraint> typeVariableConstraints = stream(method.getAnnotationsByType(TypeParameter.class))
.map(t -> withVariadicBound(t.value(), t.boundedBy().isEmpty() ? null : t.boundedBy()))
.collect(toImmutableList());

return Stream.concat(Stream.of(functionHeader.value()), stream(functionHeader.alias()))
.map(name -> new SqlInvokedFunction(
QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, name),
parameters,
typeVariableConstraints,
returnType,
functionDescription,
routineCharacteristics,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator;

import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.operator.scalar.annotations.SqlInvokedScalarFromAnnotationsParser;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.FunctionKind;
import com.facebook.presto.spi.function.Signature;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.function.SqlInvokedScalarFunction;
import com.facebook.presto.spi.function.SqlParameter;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;
import com.facebook.presto.spi.function.TypeVariableConstraint;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.Test;

import java.util.Collections;
import java.util.List;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.DEFAULT_NAMESPACE;
import static com.facebook.presto.spi.function.SqlFunctionVisibility.PUBLIC;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;

public class TestAnnotationEngineForSqlInvokedScalars
extends TestAnnotationEngine
{
@Test
public void testParseFunctionDefinition()
{
Signature expectedSignature = new Signature(
QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "sample_sql_invoked_scalar_function"),
FunctionKind.SCALAR,
new ArrayType(BIGINT).getTypeSignature(),
ImmutableList.of(INTEGER.getTypeSignature()));

List<SqlInvokedFunction> functions = SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(SingleImplementationSQLInvokedScalarFunction.class);
assertEquals(functions.size(), 1);
SqlInvokedFunction f = functions.get(0);

assertEquals(f.getSignature(), expectedSignature);
assertTrue(f.isDeterministic());
assertEquals(f.getVisibility(), PUBLIC);
assertEquals(f.getDescription(), "Simple SQL invoked scalar function");

assertEquals(f.getBody(), "RETURN SEQUENCE(1, input)");
}

@Test
public void testParseFunctionDefinitionWithTypeParameter()
{
Signature expectedSignature = new Signature(
QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "sample_sql_invoked_scalar_function_with_type_parameter"),
FunctionKind.SCALAR,
ImmutableList.of(new TypeVariableConstraint("T", false, false, null, false)),
Collections.emptyList(),
TypeSignature.parseTypeSignature("array(T)"),
ImmutableList.of(new TypeSignature("T")),
false);

List<SqlInvokedFunction> functions = SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(SingleImplementationSQLInvokedScalarFunctionWithTypeParameter.class);
assertEquals(functions.size(), 1);
SqlInvokedFunction f = functions.get(0);

assertEquals(f.getSignature(), expectedSignature);
assertTrue(f.isDeterministic());
assertEquals(f.getVisibility(), PUBLIC);
assertEquals(f.getDescription(), "Simple SQL invoked scalar function with type parameter");

assertEquals(f.getBody(), "RETURN ARRAY[input]");
}

public static class SingleImplementationSQLInvokedScalarFunction
{
@SqlInvokedScalarFunction(value = "sample_sql_invoked_scalar_function", deterministic = true, calledOnNullInput = false)
@Description("Simple SQL invoked scalar function")
@SqlParameter(name = "input", type = "integer")
@SqlType("array<bigint>")
public static String fun()
{
return "RETURN SEQUENCE(1, input)";
}
}

public static class SingleImplementationSQLInvokedScalarFunctionWithTypeParameter
{
@SqlInvokedScalarFunction(value = "sample_sql_invoked_scalar_function_with_type_parameter", deterministic = true, calledOnNullInput = false)
@Description("Simple SQL invoked scalar function with type parameter")
@TypeParameter("T")
@SqlParameter(name = "input", type = "T")
@SqlType("array<T>")
public static String fun()
{
return "RETURN ARRAY[input]";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import static com.facebook.presto.spi.function.FunctionVersion.notVersioned;
import static com.facebook.presto.spi.function.SqlFunctionVisibility.PUBLIC;
import static java.lang.String.format;
import static java.util.Collections.emptyList;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.joining;
Expand Down Expand Up @@ -74,6 +75,19 @@ public SqlInvokedFunction(
RoutineCharacteristics routineCharacteristics,
String body,
FunctionVersion version)
{
this(functionName, parameters, emptyList(), returnType, description, routineCharacteristics, body, version);
}

public SqlInvokedFunction(
QualifiedObjectName functionName,
List<Parameter> parameters,
List<TypeVariableConstraint> typeVariableConstraints,
TypeSignature returnType,
String description,
RoutineCharacteristics routineCharacteristics,
String body,
FunctionVersion version)
{
this.parameters = requireNonNull(parameters, "parameters is null");
this.description = requireNonNull(description, "description is null");
Expand All @@ -83,7 +97,8 @@ public SqlInvokedFunction(
List<TypeSignature> argumentTypes = parameters.stream()
.map(Parameter::getType)
.collect(collectingAndThen(toList(), Collections::unmodifiableList));
this.signature = new Signature(functionName, SCALAR, returnType, argumentTypes);

this.signature = new Signature(functionName, SCALAR, typeVariableConstraints, emptyList(), returnType, argumentTypes, false);
this.functionId = new SqlFunctionId(functionName, argumentTypes);
this.functionVersion = requireNonNull(version, "version is null");
this.functionHandle = version.hasVersion() ? Optional.of(new SqlFunctionHandle(this.functionId, version.toString())) : Optional.empty();
Expand Down

0 comments on commit ce6dc8c

Please sign in to comment.