diff --git a/velox/expression/VectorFunction.cpp b/velox/expression/VectorFunction.cpp index 69046f496c87..44e70cd44977 100644 --- a/velox/expression/VectorFunction.cpp +++ b/velox/expression/VectorFunction.cpp @@ -62,6 +62,13 @@ VectorFunctionMap& vectorFunctionFactories() { return factories; } +std::optional<VectorFunctionMetadata> getVectorFunctionMetadata( + const std::string& name) { + return applyToVectorFunctionEntry<VectorFunctionMetadata>( + name, + [&](const auto& /*name*/, const auto& entry) { return entry.metadata; }); +} + std::optional<std::vector<FunctionSignaturePtr>> getVectorFunctionSignatures( const std::string& name) { return applyToVectorFunctionEntry<std::vector<FunctionSignaturePtr>>( diff --git a/velox/expression/VectorFunction.h b/velox/expression/VectorFunction.h index 1c1c467bae8e..17e726148407 100644 --- a/velox/expression/VectorFunction.h +++ b/velox/expression/VectorFunction.h @@ -170,6 +170,11 @@ class SimpleFunctionAdapterFactory { virtual ~SimpleFunctionAdapterFactory() = default; }; +/// Returns the metadata of the function with the specified name. Returns std::nullopt +/// if there is no function with the specified name. +std::optional<VectorFunctionMetadata> getVectorFunctionMetadata( + const std::string& name); + /// Returns a list of signatures supported by VectorFunction with the specified /// name. Returns std::nullopt if there is no function with the specified name. std::optional<std::vector<FunctionSignaturePtr>> getVectorFunctionSignatures( diff --git a/velox/expression/fuzzer/ExpressionFuzzer.cpp b/velox/expression/fuzzer/ExpressionFuzzer.cpp index 4817d71ac462..b3cf1458a840 100644 --- a/velox/expression/fuzzer/ExpressionFuzzer.cpp +++ b/velox/expression/fuzzer/ExpressionFuzzer.cpp @@ -341,12 +341,9 @@ static void appendSpecialForms( } } -/// Returns if `functionName` with the given `argTypes` is deterministic. -/// Returns true if the function was not found or determinism cannot be -/// established. -bool isDeterministic( - const std::string& functionName, - const std::vector<TypePtr>& argTypes) { +// Returns whether `functionName` is deterministic. 
Returns true if the function was +// not found or determinism cannot be established. +bool isDeterministic(const std::string& functionName) { // We know that the 'cast', 'and', and 'or' special forms are deterministic. // Hard-code them here because they are not real functions and hence cannot // be resolved by the code below. @@ -356,15 +353,14 @@ bool isDeterministic( return true; } - if (auto typeAndMetadata = - resolveFunctionWithMetadata(functionName, argTypes)) { - return typeAndMetadata->second.deterministic; + const auto determinism = velox::isDeterministic(functionName); + if (!determinism.has_value()) { + // functionName must be a special form. + LOG(WARNING) << "Unable to determine if '" << functionName + << "' is deterministic or not. Assuming it is."; + return true; } - - // functionName must be a special form. - LOG(WARNING) << "Unable to determine if '" << functionName - << "' is deterministic or not. Assuming it is."; - return true; + return determinism.value(); } std::optional<CallableSignature> processConcreteSignature( @@ -567,13 +563,26 @@ ExpressionFuzzer::ExpressionFuzzer( continue; } - // Determine a list of concrete argument types that can bind to the - // signature. For non-parameterized signatures, these argument types will - // be used to create a callable signature. For parameterized signatures, - // these argument types are only used to fetch the function instance to - // get their determinism. 
- std::vector<TypePtr> argTypes; - if (signature->variables().empty()) { + if (!isDeterministic(function.first)) { + LOG(WARNING) << "Skipping non-deterministic function: " + << function.first << signature->toString(); + continue; + } + + if (!signature->variables().empty()) { + std::unordered_set<std::string> typeVariables; + for (const auto& [name, _] : signature->variables()) { + typeVariables.insert(name); + } + atLeastOneSupported = true; + ++supportedFunctionSignatures; + signatureTemplates_.emplace_back(SignatureTemplate{ + function.first, signature, std::move(typeVariables)}); + } else { + // Determine a list of concrete argument types that can bind to this + // non-parameterized signature. These argument types are used to + // create a callable signature. + std::vector<TypePtr> argTypes; bool supportedSignature = true; for (const auto& arg : signature->argumentTypes()) { auto resolvedType = SignatureBinder::tryResolveType(arg, {}, {}); @@ -589,37 +598,15 @@ ExpressionFuzzer::ExpressionFuzzer( << function.first << signature->toString(); continue; } - } else { - ArgumentTypeFuzzer typeFuzzer{*signature, localRng}; - typeFuzzer.fuzzReturnType(); - VELOX_CHECK_EQ( - typeFuzzer.fuzzArgumentTypes(options_.maxNumVarArgs), true); - argTypes = typeFuzzer.argumentTypes(); - } - if (!isDeterministic(function.first, argTypes)) { - LOG(WARNING) << "Skipping non-deterministic function: " - << function.first << signature->toString(); - continue; - } - - if (!signature->variables().empty()) { - std::unordered_set<std::string> typeVariables; - for (const auto& [name, _] : signature->variables()) { - typeVariables.insert(name); + if (auto callableFunction = processConcreteSignature( + function.first, + argTypes, + *signature, + options_.enableComplexTypes)) { + atLeastOneSupported = true; + ++supportedFunctionSignatures; + signatures_.emplace_back(*callableFunction); } - atLeastOneSupported = true; - ++supportedFunctionSignatures; - signatureTemplates_.emplace_back(SignatureTemplate{ 
- function.first, signature, std::move(typeVariables)}); - } else if ( - auto callableFunction = processConcreteSignature( - function.first, - argTypes, - *signature, - options_.enableComplexTypes)) { - atLeastOneSupported = true; - ++supportedFunctionSignatures; - signatures_.emplace_back(*callableFunction); } } diff --git a/velox/functions/FunctionRegistry.cpp b/velox/functions/FunctionRegistry.cpp index 0c07f8a0b9bb..5ee951c80c1e 100644 --- a/velox/functions/FunctionRegistry.cpp +++ b/velox/functions/FunctionRegistry.cpp @@ -76,6 +76,25 @@ void clearFunctionRegistry() { [](auto& functionMap) { functionMap.clear(); }); } +std::optional<bool> isDeterministic(const std::string& functionName) { + const auto simpleFunctions = + exec::simpleFunctions().getFunctionSignaturesAndMetadata(functionName); + const auto metadata = exec::getVectorFunctionMetadata(functionName); + if (simpleFunctions.empty() && !metadata.has_value()) { + return std::nullopt; + } + + for (const auto& [metadata, _] : simpleFunctions) { + if (!metadata.deterministic) { + return false; + } + } + if (metadata.has_value() && !metadata.value().deterministic) { + return false; + } + return true; +} + TypePtr resolveFunction( const std::string& functionName, const std::vector<TypePtr>& argTypes) { diff --git a/velox/functions/FunctionRegistry.h b/velox/functions/FunctionRegistry.h index 9f51a92b270a..04b86cfb49d5 100644 --- a/velox/functions/FunctionRegistry.h +++ b/velox/functions/FunctionRegistry.h @@ -35,6 +35,12 @@ FunctionSignatureMap getFunctionSignatures(); /// The mapping is function name -> list of function signatures FunctionSignatureMap getVectorFunctionSignatures(); +/// Returns whether a function is deterministic by fetching all registry entries for +/// the given function name and checking that all of them are deterministic. +/// Returns std::nullopt if the function is not found. Returns false if any of +/// the entries are not deterministic. 
+std::optional<bool> isDeterministic(const std::string& functionName); + /// Given a function name and argument types, returns /// the return type if function exists otherwise returns nullptr TypePtr resolveFunction( diff --git a/velox/functions/tests/FunctionRegistryTest.cpp b/velox/functions/tests/FunctionRegistryTest.cpp index b56e980d45e9..be34bb5bc3b7 100644 --- a/velox/functions/tests/FunctionRegistryTest.cpp +++ b/velox/functions/tests/FunctionRegistryTest.cpp @@ -24,6 +24,7 @@ #include "velox/functions/FunctionRegistry.h" #include "velox/functions/Macros.h" #include "velox/functions/Registerer.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" #include "velox/type/Type.h" @@ -495,6 +496,20 @@ TEST_F(FunctionRegistryTest, functionNameInMixedCase) { ASSERT_EQ(*result, *VARCHAR()); } +TEST_F(FunctionRegistryTest, isDeterministic) { + functions::prestosql::registerAllScalarFunctions(); + ASSERT_TRUE(isDeterministic("plus").value()); + ASSERT_TRUE(isDeterministic("in").value()); + + ASSERT_FALSE(isDeterministic("rand").value()); + ASSERT_FALSE(isDeterministic("uuid").value()); + ASSERT_FALSE(isDeterministic("shuffle").value()); + + // Functions that cannot be found in the registry. + ASSERT_FALSE(isDeterministic("cast").has_value()); + ASSERT_FALSE(isDeterministic("not_found_function").has_value()); +} + template <typename T> struct TestFunction { VELOX_DEFINE_FUNCTION_TYPES(T);