Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Spark function str_to_map when map key duplicate Policy is LAST_WIN #12317

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions velox/docs/functions/spark/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,17 @@ String Functions
SELECT str_to_map('', ',', ':'); -- {"":NULL}
SELECT str_to_map('a:1,b:2,c:3', ',', ','); -- {"a:1":NULL,"b:2":NULL,"c:3":NULL}

.. spark:function:: str_to_map_last_win(string, entryDelimiter, keyValueDelimiter) -> map(string, string)
Similar to ``str_to_map``, except that when duplicate keys are found within a single row's result,
the value of the last occurrence of the key takes precedence. This matches Spark's behavior when
``spark.sql.mapKeyDedupPolicy`` is set to ``LAST_WIN``. ::

SELECT str_to_map_last_win('a:1,b:2,a:3', ',', ':'); -- {"a":"3","b":"2"}
SELECT str_to_map_last_win('a', ',', ':'); -- {"a":NULL}
SELECT str_to_map_last_win('', ',', ':'); -- {"":NULL}
SELECT str_to_map_last_win('a:1,b:2,c:3', ',', ','); -- {"a:1":NULL,"b:2":NULL,"c:3":NULL}

.. spark:function:: substring(string, start) -> varchar
Returns the rest of ``string`` from the starting position ``start``.
Expand Down
68 changes: 68 additions & 0 deletions velox/functions/sparksql/StringToMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,72 @@ struct StringToMapFunction {
}
};

/// Splits 'input' into map entries using single-character delimiters and
/// resolves duplicate keys so that the entry appearing last in 'input' wins.
///
/// The input is scanned from its end towards its beginning, one entry at a
/// time. Because the scan runs backwards, the first time a key is seen
/// corresponds to its last occurrence in 'input'; any later sighting of the
/// same key (i.e. an earlier entry in 'input') is ignored.
///
/// Within an entry, the key is the text before the first key-value delimiter
/// and the value is everything after it. An entry without a key-value
/// delimiter is added through 'add_null()' with the whole entry as the key.
template <typename T>
struct StringToMapLastWinFunction {
  VELOX_DEFINE_FUNCTION_TYPES(T);

  // Results refer to strings in the first argument.
  static constexpr int32_t reuse_strings_from_arg = 0;

  /// @param out Map writer receiving the parsed entries.
  /// @param input String to split.
  /// @param entryDelimiter Separator between entries; must be exactly one
  /// character (enforced with VELOX_USER_CHECK_EQ).
  /// @param keyValueDelimiter Separator between a key and its value; must be
  /// exactly one character (enforced with VELOX_USER_CHECK_EQ).
  void call(
      out_type<Map<Varchar, Varchar>>& out,
      const arg_type<Varchar>& input,
      const arg_type<Varchar>& entryDelimiter,
      const arg_type<Varchar>& keyValueDelimiter) {
    VELOX_USER_CHECK_EQ(
        entryDelimiter.size(), 1, "entryDelimiter's size should be 1.");
    VELOX_USER_CHECK_EQ(
        keyValueDelimiter.size(), 1, "keyValueDelimiter's size should be 1.");

    callImpl(
        out,
        toStringView(input),
        toStringView(entryDelimiter),
        toStringView(keyValueDelimiter));
  }

 private:
  static std::string_view toStringView(const arg_type<Varchar>& input) {
    return std::string_view(input.data(), input.size());
  }

  void callImpl(
      out_type<Map<Varchar, Varchar>>& out,
      std::string_view input,
      std::string_view entryDelimiter,
      std::string_view keyValueDelimiter) const {
    // Keys already written to 'out'. The views point into 'input'.
    folly::F14FastSet<std::string_view> keys;
    const char pairDelim = entryDelimiter[0];
    const char keyValueDelim = keyValueDelimiter[0];

    // Positions are signed because -1 is used as the "before the first
    // character" sentinel. 64 bits avoids narrowing input.size().
    //
    // 'right' is the index one past the last character of the current entry,
    // i.e. the position of the entry delimiter that follows it (or
    // input.size() for the last entry).
    int64_t right = static_cast<int64_t>(input.size());

    while (right >= 0) {
      // Walk left to the preceding entry delimiter (or to -1), remembering
      // the left-most key-value delimiter seen inside the entry. If none is
      // found, the position stays equal to 'right'.
      int64_t left = right - 1;
      int64_t firstKeyValueDelimPos = right;
      while (left >= 0 && input[left] != pairDelim) {
        if (input[left] == keyValueDelim) {
          firstKeyValueDelimPos = left;
        }
        --left;
      }

      const auto key =
          input.substr(left + 1, firstKeyValueDelimPos - left - 1);
      // insert() reports whether the key is new, so a single hash lookup
      // both tests for and records the key.
      if (keys.insert(key).second) {
        if (firstKeyValueDelimPos != right) {
          auto [keyWriter, valueWriter] = out.add_item();
          keyWriter.setNoCopy(StringView(key));
          valueWriter.setNoCopy(StringView(input.substr(
              firstKeyValueDelimPos + 1, right - firstKeyValueDelimPos - 1)));
        } else {
          // No key-value delimiter in this entry.
          auto& keyWriter = out.add_null();
          keyWriter.setNoCopy(StringView(key));
        }
      }

      // Continue with the entry that ends at the delimiter just found.
      right = left;
    }
  }
};

} // namespace facebook::velox::functions::sparksql
6 changes: 6 additions & 0 deletions velox/functions/sparksql/registration/RegisterString.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ void registerStringFunctions(const std::string& prefix) {
Varchar,
Varchar,
Varchar>({prefix + "str_to_map"});
registerFunction<
sparksql::StringToMapLastWinFunction,
Map<Varchar, Varchar>,
Varchar,
Varchar,
Varchar>({prefix + "str_to_map_last_win"});
registerFunction<sparksql::LeftFunction, Varchar, Varchar, int32_t>(
{prefix + "left"});
registerFunction<sparksql::BitLengthFunction, int32_t, Varchar>(
Expand Down
87 changes: 86 additions & 1 deletion velox/functions/sparksql/tests/StringToMapTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,23 @@ class StringToMapTest : public SparkFunctionBaseTest {
expect) {
auto result = evaluateStringToMap(inputs);
auto expectVector = makeMapVector<StringView, StringView>({expect});
assertEqualVectors(result, expectVector);
assertEqualVectors(expectVector, result);
}

// Evaluates str_to_map_last_win on a one-row input. inputs[0] is supplied as
// column c0; inputs[1] and inputs[2] are spliced into the expression text as
// the quoted entry and key-value delimiter literals.
VectorPtr evaluateStringToMapLastWin(const std::vector<StringView>& inputs) {
  const std::string expression = fmt::format(
      "str_to_map_last_win(c0, '{}', '{}')", inputs[1], inputs[2]);
  auto inputRow = makeRowVector({makeFlatVector<StringView>({inputs[0]})});
  return evaluate<MapVector>(expression, inputRow);
}

// Runs str_to_map_last_win on 'inputs' (string, entry delimiter, key-value
// delimiter) and asserts that the single resulting map equals 'expect'.
void testStringToMapLastWin(
    const std::vector<StringView>& inputs,
    const std::vector<std::pair<StringView, std::optional<StringView>>>&
        expect) {
  auto actual = evaluateStringToMapLastWin(inputs);
  auto expected = makeMapVector<StringView, StringView>({expect});
  assertEqualVectors(expected, actual);
}
};

Expand Down Expand Up @@ -94,5 +110,74 @@ TEST_F(StringToMapTest, basic) {
evaluateStringToMap({":1,:2", ",", ":"}),
"Duplicate keys are not allowed: ''.");
}

// Exercises str_to_map_last_win on inputs with and without duplicate keys.
TEST_F(StringToMapTest, basicLastWinWithoutDuplicateKey) {
  // No duplicate keys.
  testStringToMapLastWin(
      {"a:1,b:2,c:3", ",", ":"}, {{"a", "1"}, {"b", "2"}, {"c", "3"}});
  testStringToMapLastWin({"a: ,b:2", ",", ":"}, {{"a", " "}, {"b", "2"}});
  testStringToMapLastWin({"a:,b:2", ",", ":"}, {{"a", ""}, {"b", "2"}});
  testStringToMapLastWin({"", ",", ":"}, {{"", std::nullopt}});
  testStringToMapLastWin({"a", ",", ":"}, {{"a", std::nullopt}});
  testStringToMapLastWin(
      {"a=1,b=2,c=3", ",", "="}, {{"a", "1"}, {"b", "2"}, {"c", "3"}});
  testStringToMapLastWin({"", ",", "="}, {{"", std::nullopt}});
  // The key-value delimiter is a letter that only occurs in the last entry.
  testStringToMapLastWin(
      {"a::1,b::2,c::3", ",", "c"},
      {{"a::1", std::nullopt}, {"b::2", std::nullopt}, {"", "::3"}});
  testStringToMapLastWin(
      {"a:1_b:2_c:3", "_", ":"}, {{"a", "1"}, {"b", "2"}, {"c", "3"}});

  // Same delimiters.
  testStringToMapLastWin(
      {"a:1,b:2,c:3", ",", ","},
      {{"a:1", std::nullopt}, {"b:2", std::nullopt}, {"c:3", std::nullopt}});
  testStringToMapLastWin(
      {"a:1_b:2_c:3", "_", "_"},
      {{"a:1", std::nullopt}, {"b:2", std::nullopt}, {"c:3", std::nullopt}});

  // Duplicate keys: the value of the last occurrence is kept.
  testStringToMapLastWin(
      {"a:1;b:2;c:333;jack:1000;b:10;c:20", ";", ":"},
      {{"a", "1"}, {"b", "10"}, {"c", "20"}, {"jack", "1000"}});

  // A trailing entry delimiter adds an empty key with a null value.
  testStringToMapLastWin(
      {"a:1;b:2;c:333;jack:1000;b:10;c:20;", ";", ":"},
      {{"", std::nullopt},
       {"a", "1"},
       {"b", "10"},
       {"c", "20"},
       {"jack", "1000"}});

  // Values may contain the key-value delimiter. The empty key occurs several
  // times and its last occurrence has no key-value delimiter.
  testStringToMapLastWin(
      {"a:::1;b:2;c:333;jack:10:00;b:10;c:20;:;;", ";", ":"},
      {{"", std::nullopt},
       {"a", "::1"},
       {"b", "10"},
       {"c", "20"},
       {"jack", "10:00"}});

  // The last entry is a lone key-value delimiter: empty key, empty value.
  testStringToMapLastWin(
      {"a:::1;b:2;c:333;jack:10:00;b:10;c:20;:", ";", ":"},
      {{"", ""}, {"a", "::1"}, {"b", "10"}, {"c", "20"}, {"jack", "10:00"}});

  // Inputs that are empty or consist only of delimiters.
  testStringToMapLastWin({"", ";", ":"}, {{"", std::nullopt}});

  testStringToMapLastWin({";", ";", ":"}, {{"", std::nullopt}});

  testStringToMapLastWin({":", ";", ":"}, {{"", ""}});

  testStringToMapLastWin({"::::", ";", ":"}, {{"", ":::"}});

  // Entries without a key-value delimiter.
  testStringToMapLastWin(
      {"jack;rose", ";", ":"},
      {{"jack", std::nullopt}, {"rose", std::nullopt}});
}
} // namespace
} // namespace facebook::velox::functions::sparksql::test