From b773c2ebfecf8eaf8a7f2a2e137f2807b6c27211 Mon Sep 17 00:00:00 2001 From: auden-woolfson Date: Wed, 13 Nov 2024 15:00:57 -0800 Subject: [PATCH] add support for multiple argument types --- .../translator/TranslatorAnnotationParser.java | 13 ++++++++++--- .../sql/relational/TestRowExpressionTranslator.java | 6 ++++-- .../facebook/presto/spi/function/SqlSignature.java | 3 ++- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/TranslatorAnnotationParser.java b/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/TranslatorAnnotationParser.java index acc914fc23759..63505efff7165 100644 --- a/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/TranslatorAnnotationParser.java +++ b/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/TranslatorAnnotationParser.java @@ -145,9 +145,16 @@ private static List methodToFunctionMetadata(ScalarTranslation for (SqlSignature signature : signatures) { ImmutableList.Builder argumentTypes = new ImmutableList.Builder<>(); - TypeSignature argumentType = parseTypeSignature(signature.argumentType()); - for (int i = 0; i < method.getParameterCount(); i++) { - argumentTypes.add(argumentType); + if (signature.argumentTypes().length == method.getParameterCount()) { + for (int i = 0; i < method.getParameterCount(); i++) { + argumentTypes.add(parseTypeSignature(signature.argumentTypes()[i])); + } + } + else { + TypeSignature argumentType = parseTypeSignature(signature.argumentType()); + for (int i = 0; i < method.getParameterCount(); i++) { + argumentTypes.add(argumentType); + } } TypeSignature returnType = parseTypeSignature(signature.returnType()); FunctionMetadata derivedMetadata; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/relational/TestRowExpressionTranslator.java b/presto-main/src/test/java/com/facebook/presto/sql/relational/TestRowExpressionTranslator.java index cc81387bb67a1..d27511b15ffe4 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/relational/TestRowExpressionTranslator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/relational/TestRowExpressionTranslator.java @@ -158,7 +158,8 @@ public Object[][] createTestBasicOperatorData() return new Object[][] { {BIGINT, BIGINT}, {INTEGER, INTEGER}, - {decimalType, decimalType} + {decimalType, decimalType}, + {BIGINT, INTEGER} }; } @@ -299,7 +300,8 @@ public static String not(String sql) @SupportedSignatures({ @SqlSignature(argumentType = StandardTypes.INTEGER, returnType = StandardTypes.INTEGER), @SqlSignature(argumentType = StandardTypes.BIGINT, returnType = StandardTypes.BIGINT), - @SqlSignature(argumentType = "decimal(38, 0)", returnType = "decimal(38, 0)")}) + @SqlSignature(argumentType = "decimal(38, 0)", returnType = "decimal(38, 0)"), + @SqlSignature(argumentTypes = {StandardTypes.BIGINT, StandardTypes.INTEGER}, returnType = StandardTypes.BIGINT)}) public static String plus(String left, String right) { return left + " -|- " + right; diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlSignature.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlSignature.java index 5bd5935543540..aba4a8ccfbece 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlSignature.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlSignature.java @@ -22,8 +22,9 @@ @Retention(RUNTIME) @Target(METHOD) public @interface SqlSignature { - String argumentType(); + String argumentType() default ""; String returnType(); + String[] argumentTypes() default {}; Class nativeContainerType() default Object.class; }