Skip to content

Commit 55060ed

Browse files
committed
[native] Add additional function metadata to function signatures endpoint
1 parent 54de1aa commit 55060ed

File tree

19 files changed

+586
-428
lines changed

19 files changed

+586
-428
lines changed

presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/JsonBasedUdfFunctionMetadata.java

+27-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import com.facebook.presto.spi.function.FunctionKind;
1919
import com.facebook.presto.spi.function.RoutineCharacteristics;
2020
import com.facebook.presto.spi.function.SqlFunctionId;
21+
import com.facebook.presto.spi.function.TypeVariableConstraint;
2122
import com.fasterxml.jackson.annotation.JsonCreator;
2223
import com.fasterxml.jackson.annotation.JsonIgnore;
2324
import com.fasterxml.jackson.annotation.JsonProperty;
@@ -63,6 +64,15 @@ public class JsonBasedUdfFunctionMetadata
6364
* Optional Aggregate-specific metadata (required for aggregation functions)
6465
*/
6566
private final Optional<AggregationFunctionMetadata> aggregateMetadata;
67+
/**
68+
* Marked to indicate whether it is a variable arity function.
69+
* A variable arity function can have a variable number of arguments of the specified type.
70+
*/
71+
private final boolean variableArity;
72+
/**
73+
* Optional list of the typeVariableConstraints.
74+
*/
75+
private final Optional<List<TypeVariableConstraint>> typeVariableConstraints;
6676
private final Optional<SqlFunctionId> functionId;
6777
private final Optional<String> version;
6878

@@ -73,23 +83,27 @@ public JsonBasedUdfFunctionMetadata(
7383
@JsonProperty("outputType") TypeSignature outputType,
7484
@JsonProperty("paramTypes") List<TypeSignature> paramTypes,
7585
@JsonProperty("schema") String schema,
86+
@JsonProperty("variableArity") boolean variableArity,
7687
@JsonProperty("routineCharacteristics") RoutineCharacteristics routineCharacteristics,
7788
@JsonProperty("aggregateMetadata") Optional<AggregationFunctionMetadata> aggregateMetadata,
7889
@JsonProperty("functionId") Optional<SqlFunctionId> functionId,
79-
@JsonProperty("version") Optional<String> version)
90+
@JsonProperty("version") Optional<String> version,
91+
@JsonProperty("typeVariableConstraints") Optional<List<TypeVariableConstraint>> typeVariableConstraints)
8092
{
8193
this.docString = requireNonNull(docString, "docString is null");
8294
this.functionKind = requireNonNull(functionKind, "functionKind is null");
8395
this.outputType = requireNonNull(outputType, "outputType is null");
8496
this.paramTypes = ImmutableList.copyOf(requireNonNull(paramTypes, "paramTypes is null"));
8597
this.schema = requireNonNull(schema, "schema is null");
98+
this.variableArity = variableArity;
8699
this.routineCharacteristics = requireNonNull(routineCharacteristics, "routineCharacteristics is null");
87100
this.aggregateMetadata = requireNonNull(aggregateMetadata, "aggregateMetadata is null");
88101
checkArgument(
89102
(functionKind == AGGREGATE && aggregateMetadata.isPresent()) || (functionKind != AGGREGATE && !aggregateMetadata.isPresent()),
90103
"aggregateMetadata must be present for aggregation functions and absent otherwise");
91104
this.functionId = requireNonNull(functionId, "functionId is null");
92105
this.version = requireNonNull(version, "version is null");
106+
this.typeVariableConstraints = requireNonNull(typeVariableConstraints, "typeVariableConstraints is null");
93107
}
94108

95109
@JsonProperty
@@ -128,6 +142,12 @@ public String getSchema()
128142
return schema;
129143
}
130144

145+
@JsonProperty
146+
public boolean getVariableArity()
147+
{
148+
return variableArity;
149+
}
150+
131151
@JsonProperty
132152
public RoutineCharacteristics getRoutineCharacteristics()
133153
{
@@ -151,4 +171,10 @@ public Optional<String> getVersion()
151171
{
152172
return version;
153173
}
174+
175+
@JsonProperty
176+
public Optional<List<TypeVariableConstraint>> getTypeVariableConstraints()
177+
{
178+
return typeVariableConstraints;
179+
}
154180
}

presto-function-namespace-managers/src/test/java/com/facebook/presto/functionNamespace/TestRestBasedFunctionNamespaceManager.java

+28-9
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import java.util.Optional;
5959

6060
import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
61+
import static java.util.Collections.emptyList;
6162
import static org.testng.Assert.assertEquals;
6263
import static org.testng.Assert.assertNotNull;
6364

@@ -140,20 +141,24 @@ public static Map<String, List<JsonBasedUdfFunctionMetadata>> createUdfSignature
140141
new TypeSignature("integer"),
141142
Collections.singletonList(new TypeSignature("integer")),
142143
"default",
144+
false,
143145
new RoutineCharacteristics(RoutineCharacteristics.Language.CPP, RoutineCharacteristics.Determinism.DETERMINISTIC, RoutineCharacteristics.NullCallClause.CALLED_ON_NULL_INPUT),
144146
Optional.empty(),
145147
Optional.of(new SqlFunctionId(QualifiedObjectName.valueOf("unittest.default.square"), ImmutableList.of(parseTypeSignature("integer")))),
146-
Optional.of("1")));
148+
Optional.of("1"),
149+
Optional.of(emptyList())));
147150
squareFunctions.add(new JsonBasedUdfFunctionMetadata(
148151
"square a double",
149152
FunctionKind.SCALAR,
150153
new TypeSignature("double"),
151154
Collections.singletonList(new TypeSignature("double")),
152155
"test_schema",
156+
false,
153157
new RoutineCharacteristics(RoutineCharacteristics.Language.CPP, RoutineCharacteristics.Determinism.DETERMINISTIC, RoutineCharacteristics.NullCallClause.CALLED_ON_NULL_INPUT),
154158
Optional.empty(),
155159
Optional.of(new SqlFunctionId(QualifiedObjectName.valueOf("unittest.test_schema.square"), ImmutableList.of(parseTypeSignature("double")))),
156-
Optional.of("1")));
160+
Optional.of("1"),
161+
Optional.of(emptyList())));
157162
udfSignatureMap.put("square", squareFunctions);
158163

159164
// array_function_1
@@ -164,30 +169,36 @@ public static Map<String, List<JsonBasedUdfFunctionMetadata>> createUdfSignature
164169
parseTypeSignature("ARRAY<ARRAY<BOOLEAN>>"),
165170
Arrays.asList(parseTypeSignature("ARRAY<ARRAY<BOOLEAN>>"), parseTypeSignature("ARRAY<ARRAY<BOOLEAN>>")),
166171
"default",
172+
false,
167173
new RoutineCharacteristics(RoutineCharacteristics.Language.CPP, RoutineCharacteristics.Determinism.DETERMINISTIC, RoutineCharacteristics.NullCallClause.CALLED_ON_NULL_INPUT),
168174
Optional.empty(),
169175
Optional.of(new SqlFunctionId(QualifiedObjectName.valueOf("unittest.default.array_function_1"), ImmutableList.of(parseTypeSignature("ARRAY<ARRAY<BOOLEAN>>"), parseTypeSignature("ARRAY<ARRAY<BOOLEAN>>")))),
170-
Optional.of("1")));
176+
Optional.of("1"),
177+
Optional.of(emptyList())));
171178
arrayFunction1.add(new JsonBasedUdfFunctionMetadata(
172179
"combines two float arrays into one",
173180
FunctionKind.SCALAR,
174181
parseTypeSignature("ARRAY<ARRAY<BIGINT>>"),
175182
Arrays.asList(parseTypeSignature("ARRAY<ARRAY<BIGINT>>"), parseTypeSignature("ARRAY<ARRAY<BIGINT>>")),
176183
"test_schema",
184+
false,
177185
new RoutineCharacteristics(RoutineCharacteristics.Language.CPP, RoutineCharacteristics.Determinism.DETERMINISTIC, RoutineCharacteristics.NullCallClause.CALLED_ON_NULL_INPUT),
178186
Optional.empty(),
179187
Optional.of(new SqlFunctionId(QualifiedObjectName.valueOf("unittest.test_schema.array_function_1"), ImmutableList.of(parseTypeSignature("ARRAY<ARRAY<BIGINT>>"), parseTypeSignature("ARRAY<ARRAY<BIGINT>>")))),
180-
Optional.of("1")));
188+
Optional.of("1"),
189+
Optional.of(emptyList())));
181190
arrayFunction1.add(new JsonBasedUdfFunctionMetadata(
182191
"combines two double arrays into one",
183192
FunctionKind.SCALAR,
184193
parseTypeSignature("ARRAY<DOUBLE>"),
185194
Arrays.asList(parseTypeSignature("ARRAY<DOUBLE>"), TypeSignature.parseTypeSignature("ARRAY<DOUBLE>")),
186195
"test_schema",
196+
false,
187197
new RoutineCharacteristics(RoutineCharacteristics.Language.CPP, RoutineCharacteristics.Determinism.DETERMINISTIC, RoutineCharacteristics.NullCallClause.CALLED_ON_NULL_INPUT),
188198
Optional.empty(),
189199
Optional.of(new SqlFunctionId(QualifiedObjectName.valueOf("unittest.test_schema.array_function_1"), ImmutableList.of(parseTypeSignature("ARRAY<DOUBLE>"), parseTypeSignature("ARRAY<DOUBLE>")))),
190-
Optional.of("1")));
200+
Optional.of("1"),
201+
Optional.of(emptyList())));
191202
udfSignatureMap.put("array_function_1", arrayFunction1);
192203

193204
// array_function_2
@@ -198,20 +209,24 @@ public static Map<String, List<JsonBasedUdfFunctionMetadata>> createUdfSignature
198209
TypeSignature.parseTypeSignature("ARRAY<map<BIGINT, DOUBLE>>"),
199210
Arrays.asList(TypeSignature.parseTypeSignature("ARRAY<map<BIGINT, DOUBLE>>"), TypeSignature.parseTypeSignature("ARRAY<varchar>")),
200211
"default",
212+
false,
201213
new RoutineCharacteristics(RoutineCharacteristics.Language.CPP, RoutineCharacteristics.Determinism.DETERMINISTIC, RoutineCharacteristics.NullCallClause.CALLED_ON_NULL_INPUT),
202214
Optional.empty(),
203215
Optional.of(new SqlFunctionId(QualifiedObjectName.valueOf("unittest.default.array_function_2"), ImmutableList.of(parseTypeSignature("ARRAY<map<BIGINT, DOUBLE>>"), parseTypeSignature("ARRAY<varchar>")))),
204-
Optional.of("1")));
216+
Optional.of("1"),
217+
Optional.of(emptyList())));
205218
arrayFunction2.add(new JsonBasedUdfFunctionMetadata(
206219
"transforms inputs into the output",
207220
FunctionKind.SCALAR,
208221
TypeSignature.parseTypeSignature("ARRAY<map<BIGINT, DOUBLE>>"),
209222
Arrays.asList(TypeSignature.parseTypeSignature("ARRAY<map<BIGINT, DOUBLE>>"), TypeSignature.parseTypeSignature("ARRAY<ARRAY<BOOLEAN>>"), TypeSignature.parseTypeSignature("ARRAY<varchar>")),
210223
"test_schema",
224+
false,
211225
new RoutineCharacteristics(RoutineCharacteristics.Language.CPP, RoutineCharacteristics.Determinism.DETERMINISTIC, RoutineCharacteristics.NullCallClause.CALLED_ON_NULL_INPUT),
212226
Optional.empty(),
213227
Optional.of(new SqlFunctionId(QualifiedObjectName.valueOf("unittest.test_schema.array_function_2"), ImmutableList.of(parseTypeSignature("ARRAY<map<BIGINT, DOUBLE>>"), parseTypeSignature("ARRAY<ARRAY<BOOLEAN>>"), parseTypeSignature("ARRAY<varchar>")))),
214-
Optional.of("1")));
228+
Optional.of("1"),
229+
Optional.of(emptyList())));
215230
udfSignatureMap.put("array_function_2", arrayFunction2);
216231

217232
return udfSignatureMap;
@@ -229,20 +244,24 @@ public static Map<String, List<JsonBasedUdfFunctionMetadata>> createUpdatedUdfSi
229244
new TypeSignature("integer"),
230245
Collections.singletonList(new TypeSignature("integer")),
231246
"default",
247+
false,
232248
new RoutineCharacteristics(RoutineCharacteristics.Language.CPP, RoutineCharacteristics.Determinism.DETERMINISTIC, RoutineCharacteristics.NullCallClause.CALLED_ON_NULL_INPUT),
233249
Optional.empty(),
234250
Optional.of(new SqlFunctionId(QualifiedObjectName.valueOf("unittest.default.square"), ImmutableList.of(parseTypeSignature("integer")))),
235-
Optional.of("1")));
251+
Optional.of("1"),
252+
Optional.of(emptyList())));
236253
squareFunctions.add(new JsonBasedUdfFunctionMetadata(
237254
"square a double",
238255
FunctionKind.SCALAR,
239256
new TypeSignature("double"),
240257
Collections.singletonList(new TypeSignature("double")),
241258
"test_schema",
259+
false,
242260
new RoutineCharacteristics(RoutineCharacteristics.Language.CPP, RoutineCharacteristics.Determinism.DETERMINISTIC, RoutineCharacteristics.NullCallClause.CALLED_ON_NULL_INPUT),
243261
Optional.empty(),
244262
Optional.of(new SqlFunctionId(QualifiedObjectName.valueOf("unittest.test_schema.square"), ImmutableList.of(parseTypeSignature("double")))),
245-
Optional.of("1")));
263+
Optional.of("1"),
264+
Optional.of(emptyList())));
246265
udfSignatureMap.put("square", squareFunctions);
247266

248267
return udfSignatureMap;

presto-function-server/src/main/java/com/facebook/presto/server/FunctionResource.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ private static JsonBasedUdfFunctionMetadata sqlFunctionToMetadata(SqlFunction fu
129129
function.getSignature().getReturnType(),
130130
function.getSignature().getArgumentTypes(),
131131
function.getSignature().getName().getSchemaName(),
132+
function.getSignature().isVariableArity(),
132133
new RoutineCharacteristics(
133134
JAVA,
134135
function.isDeterministic() ? DETERMINISTIC : NOT_DETERMINISTIC,
@@ -138,7 +139,8 @@ private static JsonBasedUdfFunctionMetadata sqlFunctionToMetadata(SqlFunction fu
138139
new SqlFunctionId(
139140
function.getSignature().getName(),
140141
function.getSignature().getArgumentTypes())),
141-
Optional.of("1"));
142+
Optional.of("1"),
143+
Optional.of(function.getSignature().getTypeVariableConstraints()));
142144
}
143145

144146
@GET

presto-native-execution/presto_cpp/main/types/FunctionMetadata.cpp

+43-4
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ const protocol::AggregationFunctionMetadata getAggregationFunctionMetadata(
8686
const std::string& name,
8787
const AggregateFunctionSignature& signature) {
8888
protocol::AggregationFunctionMetadata metadata;
89-
metadata.intermediateType = signature.intermediateType().toString();
89+
metadata.intermediateType =
90+
boost::algorithm::to_lower_copy(signature.intermediateType().toString());
9091
metadata.isOrderSensitive =
9192
getAggregateFunctionEntry(name)->metadata.orderSensitive;
9293
return metadata;
@@ -140,6 +141,24 @@ const protocol::RoutineCharacteristics getRoutineCharacteristics(
140141
return routineCharacteristics;
141142
}
142143

144+
const std::vector<protocol::TypeVariableConstraint> getTypeVariableConstraints(
145+
const FunctionSignature& functionSignature) {
146+
std::vector<protocol::TypeVariableConstraint> typeVariableConstraints;
147+
const auto functionVariables = functionSignature.variables();
148+
for (const auto& [name, signature] : functionVariables) {
149+
if (signature.isTypeParameter()) {
150+
protocol::TypeVariableConstraint typeVariableConstraint;
151+
typeVariableConstraint.name =
152+
boost::algorithm::to_lower_copy(signature.name());
153+
typeVariableConstraint.orderableRequired = signature.orderableTypesOnly();
154+
typeVariableConstraint.comparableRequired =
155+
signature.comparableTypesOnly();
156+
typeVariableConstraints.emplace_back(typeVariableConstraint);
157+
}
158+
}
159+
return typeVariableConstraints;
160+
}
161+
143162
std::optional<protocol::JsonBasedUdfFunctionMetadata> buildFunctionMetadata(
144163
const std::string& name,
145164
const std::string& schema,
@@ -152,19 +171,25 @@ std::optional<protocol::JsonBasedUdfFunctionMetadata> buildFunctionMetadata(
152171
if (!isValidPrestoType(signature.returnType())) {
153172
return std::nullopt;
154173
}
155-
metadata.outputType = signature.returnType().toString();
174+
metadata.outputType =
175+
boost::algorithm::to_lower_copy(signature.returnType().toString());
156176

157177
const auto& argumentTypes = signature.argumentTypes();
158178
std::vector<std::string> paramTypes(argumentTypes.size());
159179
for (auto i = 0; i < argumentTypes.size(); i++) {
160180
if (!isValidPrestoType(argumentTypes.at(i))) {
161181
return std::nullopt;
162182
}
163-
paramTypes[i] = argumentTypes.at(i).toString();
183+
paramTypes[i] =
184+
boost::algorithm::to_lower_copy(argumentTypes.at(i).toString());
164185
}
165186
metadata.paramTypes = paramTypes;
166187
metadata.schema = schema;
188+
metadata.variableArity = signature.variableArity();
167189
metadata.routineCharacteristics = getRoutineCharacteristics(name, kind);
190+
metadata.typeVariableConstraints =
191+
std::make_shared<std::vector<protocol::TypeVariableConstraint>>(
192+
getTypeVariableConstraints(signature));
168193

169194
if (aggregateSignature) {
170195
metadata.aggregateMetadata =
@@ -199,8 +224,22 @@ json buildAggregateMetadata(
199224
getWindowFunctionSignatures(name).has_value(),
200225
"Aggregate function {} not registered as a window function",
201226
name);
227+
228+
// The functions returned by this endpoint are stored as SqlInvokedFunction
229+
// objects, with SqlFunctionId serving as the primary key. SqlFunctionId is
230+
// derived from both the functionName and argumentTypes parameters. Returning
231+
// the same function twice—once as an aggregate function and once as a window
232+
// function introduces ambiguity, as functionKind is not a component of
233+
// SqlFunctionId. For any aggregate function utilized as a window function,
234+
// the function’s metadata can be obtained from the associated aggregate
235+
// function implementation for further processing. For additional information,
236+
// refer to the following: •
237+
// https://github.com/prestodb/presto/blob/master/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlFunctionId.java
238+
//
239+
// https://github.com/prestodb/presto/blob/master/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlInvokedFunction.java
240+
202241
const std::vector<protocol::FunctionKind> kinds = {
203-
protocol::FunctionKind::AGGREGATE, protocol::FunctionKind::WINDOW};
242+
protocol::FunctionKind::AGGREGATE};
204243
json j = json::array();
205244
json tj;
206245
for (const auto& kind : kinds) {

presto-native-execution/presto_cpp/main/types/tests/FunctionMetadataTest.cpp

+9-5
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class FunctionMetadataTest : public ::testing::Test {
6161
};
6262

6363
TEST_F(FunctionMetadataTest, approxMostFrequent) {
64-
testFunction("approx_most_frequent", "ApproxMostFrequent.json", 12);
64+
testFunction("approx_most_frequent", "ApproxMostFrequent.json", 6);
6565
}
6666

6767
TEST_F(FunctionMetadataTest, arrayFrequency) {
@@ -73,13 +73,17 @@ TEST_F(FunctionMetadataTest, combinations) {
7373
}
7474

7575
TEST_F(FunctionMetadataTest, covarSamp) {
76-
testFunction("covar_samp", "CovarSamp.json", 4);
76+
testFunction("covar_samp", "CovarSamp.json", 2);
7777
}
7878

7979
TEST_F(FunctionMetadataTest, elementAt) {
8080
testFunction("element_at", "ElementAt.json", 3);
8181
}
8282

83+
TEST_F(FunctionMetadataTest, greatest) {
84+
testFunction("greatest", "Greatest.json", 13);
85+
}
86+
8387
TEST_F(FunctionMetadataTest, lead) {
8488
testFunction("lead", "Lead.json", 3);
8589
}
@@ -89,17 +93,17 @@ TEST_F(FunctionMetadataTest, ntile) {
8993
}
9094

9195
TEST_F(FunctionMetadataTest, setAgg) {
92-
testFunction("set_agg", "SetAgg.json", 2);
96+
testFunction("set_agg", "SetAgg.json", 1);
9397
}
9498

9599
TEST_F(FunctionMetadataTest, stddevSamp) {
96-
testFunction("stddev_samp", "StddevSamp.json", 10);
100+
testFunction("stddev_samp", "StddevSamp.json", 5);
97101
}
98102

99103
TEST_F(FunctionMetadataTest, transformKeys) {
100104
testFunction("transform_keys", "TransformKeys.json", 1);
101105
}
102106

103107
TEST_F(FunctionMetadataTest, variance) {
104-
testFunction("variance", "Variance.json", 10);
108+
testFunction("variance", "Variance.json", 5);
105109
}

0 commit comments

Comments
 (0)