From 6a2b5f002943e09f6ef50eb02a674ff4abc7b32d Mon Sep 17 00:00:00 2001 From: Archon Date: Mon, 16 May 2022 17:02:44 +0800 Subject: [PATCH] PrestoSQL migrate to Trino --- .gitignore | 29 +++++ README.md | 29 +++++ pom.xml | 76 +++++++++++++ .../com/github/archongum/trino/UdfPlugin.java | 45 ++++++++ .../ArrayAggDistinctAggregation.java | 69 ++++++++++++ .../ArrayAggDistinctIntegerAggregation.java | 69 ++++++++++++ .../aggregate/MaxCountElementAggregation.java | 71 ++++++++++++ .../trino/udf/aggregate/state/MapState.java | 25 +++++ .../udf/aggregate/state/MapStateFactory.java | 77 +++++++++++++ .../aggregate/state/MapStateSerializer.java | 52 +++++++++ .../trino/udf/aggregate/state/SetState.java | 26 +++++ .../udf/aggregate/state/SetStateFactory.java | 80 ++++++++++++++ .../udf/aggregate/state/SetStateLong.java | 25 +++++ .../aggregate/state/SetStateLongFactory.java | 80 ++++++++++++++ .../state/SetStateLongSerializer.java | 50 +++++++++ .../aggregate/state/SetStateSerializer.java | 50 +++++++++ .../scalar/ArrayMaxCountElementFunction.java | 103 ++++++++++++++++++ .../trino/udf/scalar/CommonFunctions.java | 25 +++++ .../trino/udf/scalar/DateTimeFunctions.java | 66 +++++++++++ .../trino/udf/scalar/DateTimeUtils.java | 102 +++++++++++++++++ .../META-INF/services/io.trino.spi.Plugin | 1 + .../udf/aggregate/TestAggreateFunctions.java | 39 +++++++ .../udf/scalar/TestDateTimeFunctions.java | 39 +++++++ unit_test.sql | 47 ++++++++ 24 files changed, 1275 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 pom.xml create mode 100644 src/main/java/com/github/archongum/trino/UdfPlugin.java create mode 100644 src/main/java/com/github/archongum/trino/udf/aggregate/ArrayAggDistinctAggregation.java create mode 100644 src/main/java/com/github/archongum/trino/udf/aggregate/ArrayAggDistinctIntegerAggregation.java create mode 100644 src/main/java/com/github/archongum/trino/udf/aggregate/MaxCountElementAggregation.java create mode 100644 
src/main/java/com/github/archongum/trino/udf/aggregate/state/MapState.java create mode 100644 src/main/java/com/github/archongum/trino/udf/aggregate/state/MapStateFactory.java create mode 100644 src/main/java/com/github/archongum/trino/udf/aggregate/state/MapStateSerializer.java create mode 100644 src/main/java/com/github/archongum/trino/udf/aggregate/state/SetState.java create mode 100644 src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateFactory.java create mode 100644 src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateLong.java create mode 100644 src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateLongFactory.java create mode 100644 src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateLongSerializer.java create mode 100644 src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateSerializer.java create mode 100644 src/main/java/com/github/archongum/trino/udf/scalar/ArrayMaxCountElementFunction.java create mode 100644 src/main/java/com/github/archongum/trino/udf/scalar/CommonFunctions.java create mode 100644 src/main/java/com/github/archongum/trino/udf/scalar/DateTimeFunctions.java create mode 100644 src/main/java/com/github/archongum/trino/udf/scalar/DateTimeUtils.java create mode 100644 src/main/resources/META-INF/services/io.trino.spi.Plugin create mode 100644 src/test/com/github/archongum/trino/udf/aggregate/TestAggreateFunctions.java create mode 100644 src/test/com/github/archongum/trino/udf/scalar/TestDateTimeFunctions.java create mode 100644 unit_test.sql diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bc00705 --- /dev/null +++ b/.gitignore @@ -0,0 +1,29 @@ +.idea +*.iml + +# Compiled class file +*.class + +# Log file +*.log + +# BlueJ files +*.ctxt + +# Mobile Tools for Java (J2ME) +.mtj.tmp/ + +# Package Files # +*.jar +*.war +*.ear +*.zip +*.tar.gz +*.rar + +# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 
+hs_err_pid* + +# idea +target/ +out/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..70303ae --- /dev/null +++ b/README.md @@ -0,0 +1,29 @@ +# Installation +1. `mvn clean assembly:assembly` +2. Copy `trino-udf-*-jar-with-dependencies.jar` to `${TRINO_HOME}/plugin/custom-functions/` in all Trino nodes. +(create directory if not exists) +3. Restart Trino cluster + +# Versions +- JDK-11 +- Trino-380 + +# Functions +## Scalar Functions +| Function | Return Type | Argument Types | Description | Usage | +|-------------------------|-------------|----------------|--------------------------------------------------------------------------------------|-----------------------------------------| +| first_day | date | date | first day of month | first_day(current_date) | +| last_day | date | date | last day of month | last_day(current_date) | +| yesterday | date | | yesterday | yesterday() | +| last_second | timestamp | date | last second of the date | last_second(current_date) | +| yesterday_last_second | timestamp | | last second of yesterday | yesterday_last_second() | +| to_datetime | timestamp | date, varchar | combine date and a 'HH:mm:ss' time string into a timestamp | to_datetime(current_date, '23:59:59') | +| array_max_count_element | T | array(T) | Get maximum count element (null is not counting; if has multiple return one of them) | array_max_count_element(array['1','2']) | +| rand | double | varchar | Return double in [0, 1) | rand(varchar) | + +## Aggregate Functions +| Function | Return Type | Argument Types | Description | Usage | +|----------------------------| ----------- |----------------| ------------------------------------------------------------------------------------ | ----------------------- | +| max_count_element | VARCHAR | VARCHAR | Get maximum count element (null is not counting; if has multiple return one of them) | max_count_element(name) | +| array_agg_distinct | INTEGER | array(VARCHAR) | Count distinct array elements. input: array(VARCHAR), output: integer. 
| array_agg_distinct(ids) | +| array_agg_distinct_integer | INTEGER | array(BIGINT) | Count distinct array elements. input: array(BIGINT), output: integer. | array_agg_distinct_integer(ids) | diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000..0f48807 --- /dev/null +++ b/pom.xml @@ -0,0 +1,76 @@ + + + 4.0.0 + + com.github.archongum + trino-udf + 4 + + + 380 + 11 + 11 + 3.8.0 + 2.13.1 + 5.4.2 + + + + + + io.trino + trino-spi + ${trino.version} + provided + + + + io.trino + trino-array + ${trino.version} + + + it.unimi.dsi + fastutil + + + + + + com.fasterxml.jackson.core + jackson-databind + ${jackson.version} + + + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + ${maven.compiler.version} + + utf-8 + + + + + maven-assembly-plugin + + + jar-with-dependencies + + + + + + diff --git a/src/main/java/com/github/archongum/trino/UdfPlugin.java b/src/main/java/com/github/archongum/trino/UdfPlugin.java new file mode 100644 index 0000000..459131f --- /dev/null +++ b/src/main/java/com/github/archongum/trino/UdfPlugin.java @@ -0,0 +1,45 @@ +/* + * Copyright 2013-2016 Qubole + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.github.archongum.trino; + +import java.util.HashSet; +import java.util.Set; +import com.github.archongum.trino.udf.aggregate.ArrayAggDistinctAggregation; +import com.github.archongum.trino.udf.aggregate.ArrayAggDistinctIntegerAggregation; +import com.github.archongum.trino.udf.aggregate.MaxCountElementAggregation; +import com.github.archongum.trino.udf.scalar.ArrayMaxCountElementFunction; +import com.github.archongum.trino.udf.scalar.CommonFunctions; +import com.github.archongum.trino.udf.scalar.DateTimeFunctions; +import io.trino.spi.Plugin; + + +/** + * @author Archon 2018年9月20日 + */ +public class UdfPlugin implements Plugin { + @Override + public Set> getFunctions() + { + Set> set = new HashSet<>(); + set.add(ArrayMaxCountElementFunction.class); + set.add(CommonFunctions.class); + set.add(DateTimeFunctions.class); + set.add(MaxCountElementAggregation.class); + set.add(ArrayAggDistinctAggregation.class); + set.add(ArrayAggDistinctIntegerAggregation.class); + return set; + } +} diff --git a/src/main/java/com/github/archongum/trino/udf/aggregate/ArrayAggDistinctAggregation.java b/src/main/java/com/github/archongum/trino/udf/aggregate/ArrayAggDistinctAggregation.java new file mode 100644 index 0000000..97eec00 --- /dev/null +++ b/src/main/java/com/github/archongum/trino/udf/aggregate/ArrayAggDistinctAggregation.java @@ -0,0 +1,69 @@ +package com.github.archongum.trino.udf.aggregate; + +import java.util.HashSet; +import java.util.Set; +import com.github.archongum.trino.udf.aggregate.state.SetState; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Description; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.type.StandardTypes; +import static io.trino.spi.type.IntegerType.INTEGER; +import static 
io.trino.spi.type.VarcharType.VARCHAR; + + +/** + * @author Archon 2021年10月21日 + */ +@AggregationFunction("array_agg_distinct") +@Description("Count distinct array elements. input: array(VARCHAR), output: integer.") +public class ArrayAggDistinctAggregation { + + @InputFunction + public static void input(SetState state, @SqlType("array(VARCHAR)") Block block) { + if (block.getPositionCount() == 0) { + return; + } + + Set set = state.getSet(); + if (set == null) { + set = new HashSet<>(); + state.setSet(set); + } + + for (int i = 0; i < block.getPositionCount(); i++) { + if (block.isNull(i)) { + continue; + } + String curElement = VARCHAR.getSlice(block, i).toStringUtf8(); + set.add(curElement); + } + } + + @CombineFunction + public static void combine(SetState state, SetState otherState) { + Set prev = state.getSet(); + Set input = otherState.getSet(); + if (prev == null) { + state.setSet(input); + } else { + if (input != null && !input.isEmpty()) { + prev.addAll(input); + } + } + } + + @OutputFunction(StandardTypes.INTEGER) + public static void output(SetState state, BlockBuilder out) { + Set set = state.getSet(); + if (set == null || set.isEmpty()) { + out.appendNull(); + } else { + INTEGER.writeLong(out, set.size()); + } + } +} diff --git a/src/main/java/com/github/archongum/trino/udf/aggregate/ArrayAggDistinctIntegerAggregation.java b/src/main/java/com/github/archongum/trino/udf/aggregate/ArrayAggDistinctIntegerAggregation.java new file mode 100644 index 0000000..bae3a6f --- /dev/null +++ b/src/main/java/com/github/archongum/trino/udf/aggregate/ArrayAggDistinctIntegerAggregation.java @@ -0,0 +1,69 @@ +package com.github.archongum.trino.udf.aggregate; + +import java.util.HashSet; +import java.util.Set; +import com.github.archongum.trino.udf.aggregate.state.SetStateLong; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.CombineFunction; +import 
io.trino.spi.function.Description; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.type.StandardTypes; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.IntegerType.INTEGER; + + +/** + * @author Archon 2021年10月21日 + */ +@AggregationFunction("array_agg_distinct_integer") +@Description("Count distinct array elements. input: array(BIGINT), output: integer.") +public class ArrayAggDistinctIntegerAggregation { + + @InputFunction + public static void input(SetStateLong state, @SqlType("array(BIGINT)") Block block) { + if (block.getPositionCount() == 0) { + return; + } + + Set set = state.getSet(); + if (set == null) { + set = new HashSet<>(); + state.setSet(set); + } + + for (int i = 0; i < block.getPositionCount(); i++) { + if (block.isNull(i)) { + continue; + } + Long curElement = BIGINT.getLong(block, i); + set.add(curElement); + } + } + + @CombineFunction + public static void combine(SetStateLong state, SetStateLong otherState) { + Set prev = state.getSet(); + Set input = otherState.getSet(); + if (prev == null) { + state.setSet(input); + } else { + if (input != null && !input.isEmpty()) { + prev.addAll(input); + } + } + } + + @OutputFunction(StandardTypes.INTEGER) + public static void output(SetStateLong state, BlockBuilder out) { + Set set = state.getSet(); + if (set == null || set.isEmpty()) { + out.appendNull(); + } else { + INTEGER.writeLong(out, set.size()); + } + } +} diff --git a/src/main/java/com/github/archongum/trino/udf/aggregate/MaxCountElementAggregation.java b/src/main/java/com/github/archongum/trino/udf/aggregate/MaxCountElementAggregation.java new file mode 100644 index 0000000..066bc9a --- /dev/null +++ b/src/main/java/com/github/archongum/trino/udf/aggregate/MaxCountElementAggregation.java @@ -0,0 +1,71 @@ +package com.github.archongum.trino.udf.aggregate; + +import 
com.github.archongum.trino.udf.aggregate.state.MapState; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Description; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.type.StandardTypes; + +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; + +import static io.trino.spi.type.VarcharType.VARCHAR; + + +/** + * @author Archon 2019年8月30日 + */ +@AggregationFunction("max_count_element") +@Description("Get maximum count element (null is not counting; if has multiple return one of them)") +public class MaxCountElementAggregation { + @InputFunction + public static void input(MapState state, @SqlType(StandardTypes.VARCHAR) Slice value) { + Map map = state.getMap(); + if (map == null) { + map = new HashMap<>(16); + state.setMap(map); + } + String v = value.toStringUtf8(); + Long cnt = map.get(v); + if (cnt == null) { + map.put(v, 1L); + } else { + map.put(v, cnt+1); + } + } + + @CombineFunction + public static void combine(MapState state, MapState otherState) { + if (state.getMap() == null && otherState.getMap() == null) { + return; + } + if (otherState.getMap() == null && state.getMap() != null) { + otherState.setMap(state.getMap()); + return; + } + if (state.getMap() == null && otherState.getMap() != null) { + state.setMap(otherState.getMap()); + return; + } + + otherState.getMap().forEach((k, v) -> state.getMap().merge(k, v, Long::sum)); + } + + @OutputFunction(StandardTypes.VARCHAR) + public static void output(MapState state, BlockBuilder out) + { + if (state.getMap().isEmpty()) { + out.appendNull(); + } else { + VARCHAR.writeSlice(out, + Slices.utf8Slice(state.getMap().entrySet().stream().max(Entry.comparingByValue()).get().getKey())); + } + } +} diff --git 
a/src/main/java/com/github/archongum/trino/udf/aggregate/state/MapState.java b/src/main/java/com/github/archongum/trino/udf/aggregate/state/MapState.java new file mode 100644 index 0000000..e7d0070 --- /dev/null +++ b/src/main/java/com/github/archongum/trino/udf/aggregate/state/MapState.java @@ -0,0 +1,25 @@ +package com.github.archongum.trino.udf.aggregate.state; + +import java.util.Map; +import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.AccumulatorStateMetadata; + + +/** + * @author Archon 2019年8月30日 + */ +@AccumulatorStateMetadata(stateSerializerClass = MapStateSerializer.class, stateFactoryClass = MapStateFactory.class) +public interface MapState extends AccumulatorState { + + /** + * get map + * @return map + */ + Map getMap(); + + /** + * set map + * @param value map + */ + void setMap(Map value); +} diff --git a/src/main/java/com/github/archongum/trino/udf/aggregate/state/MapStateFactory.java b/src/main/java/com/github/archongum/trino/udf/aggregate/state/MapStateFactory.java new file mode 100644 index 0000000..607de43 --- /dev/null +++ b/src/main/java/com/github/archongum/trino/udf/aggregate/state/MapStateFactory.java @@ -0,0 +1,77 @@ +package com.github.archongum.trino.udf.aggregate.state; + +import java.util.HashMap; +import java.util.Map; +import io.trino.array.ObjectBigArray; +import io.trino.spi.function.AccumulatorStateFactory; +import io.trino.spi.function.GroupedAccumulatorState; + + +/** + * @author Archon 8/30/19 + * @since + */ +public class MapStateFactory implements AccumulatorStateFactory { + + public static final class SingleMapState implements MapState { + + private Map map = new HashMap<>(); + + @Override + public Map getMap() { + return map; + } + + @Override + public void setMap(Map value) { + this.map = value; + } + + @Override + public long getEstimatedSize() { + return map.size(); + } + } + + public static class GroupedMapState implements GroupedAccumulatorState, MapState { + + private final ObjectBigArray> 
maps = new ObjectBigArray<>(); + private long groupId; + + @Override + public Map getMap() { + return maps.get(groupId); + } + + @Override + public void setMap(Map value) { + maps.set(groupId, value); + } + + @Override + public void setGroupId(long groupId) { + this.groupId = groupId; + } + + @Override + public void ensureCapacity(long size) { + maps.ensureCapacity(size); + } + + @Override + public long getEstimatedSize() { + return maps.sizeOf(); + } + } + + @Override + public MapState createSingleState() { + return new SingleMapState(); + } + + @Override + public MapState createGroupedState() { + return new GroupedMapState(); + } + +} diff --git a/src/main/java/com/github/archongum/trino/udf/aggregate/state/MapStateSerializer.java b/src/main/java/com/github/archongum/trino/udf/aggregate/state/MapStateSerializer.java new file mode 100644 index 0000000..0a27ab7 --- /dev/null +++ b/src/main/java/com/github/archongum/trino/udf/aggregate/state/MapStateSerializer.java @@ -0,0 +1,52 @@ +package com.github.archongum.trino.udf.aggregate.state; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.airlift.slice.Slices; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AccumulatorStateSerializer; +import io.trino.spi.type.Type; + +import java.io.IOException; +import java.util.HashMap; + +import static io.trino.spi.type.VarcharType.VARCHAR; + + +/** + * @author Archon 8/30/19 + * @since + */ +public class MapStateSerializer implements AccumulatorStateSerializer +{ + + private final ObjectMapper mapper = new ObjectMapper(); + + @Override + public Type getSerializedType() { + return VARCHAR; + } + + @Override + public void serialize(MapState state, BlockBuilder out) { + try { + String jsonResult = mapper.writeValueAsString(state.getMap()); + VARCHAR.writeSlice(out, Slices.utf8Slice(jsonResult)); + 
} catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + @Override + public void deserialize(Block block, int index, MapState state) { + try { + TypeReference> typeRef = new TypeReference<>() {}; + state.setMap(mapper.readValue(VARCHAR.getSlice(block, index).toStringUtf8(), typeRef)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} + diff --git a/src/main/java/com/github/archongum/trino/udf/aggregate/state/SetState.java b/src/main/java/com/github/archongum/trino/udf/aggregate/state/SetState.java new file mode 100644 index 0000000..6bf8991 --- /dev/null +++ b/src/main/java/com/github/archongum/trino/udf/aggregate/state/SetState.java @@ -0,0 +1,26 @@ +package com.github.archongum.trino.udf.aggregate.state; + +import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.AccumulatorStateMetadata; + +import java.util.Set; + + +/** + * @author Archon 2021年10月21日 + */ +@AccumulatorStateMetadata(stateSerializerClass = SetStateSerializer.class, stateFactoryClass = SetStateFactory.class) +public interface SetState extends AccumulatorState { + + /** + * get set + * @return set + */ + Set getSet(); + + /** + * set set + * @param value set + */ + void setSet(Set value); +} diff --git a/src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateFactory.java b/src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateFactory.java new file mode 100644 index 0000000..e155acf --- /dev/null +++ b/src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateFactory.java @@ -0,0 +1,80 @@ +package com.github.archongum.trino.udf.aggregate.state; + +import java.util.HashSet; +import java.util.Set; +import io.trino.array.ObjectBigArray; +import io.trino.spi.function.AccumulatorStateFactory; +import io.trino.spi.function.GroupedAccumulatorState; + + +/** + * @author Archon 8/30/19 + * @since + */ +public class SetStateFactory implements AccumulatorStateFactory { + + public static final class 
SingleSetState implements SetState { + + private Set set = new HashSet<>(); + + @Override + public Set getSet() { + return set; + } + + @Override + public void setSet(Set value) { + this.set = value; + } + + @Override + public long getEstimatedSize() { + return set.size(); + } + } + + public static class GroupedSetState implements GroupedAccumulatorState, SetState { + + private final ObjectBigArray> container = new ObjectBigArray<>(); + private long groupId; + + @Override + public Set getSet() { + return container.get(groupId); + } + + @Override + public void setSet(Set value) { + container.set(groupId, value); + } + + @Override + public void setGroupId(long groupId) { + this.groupId = groupId; + if (this.getSet() == null) { + this.setSet(new HashSet<>()); + } + } + + @Override + public void ensureCapacity(long size) { + container.ensureCapacity(size); + } + + @Override + public long getEstimatedSize() { + return container.sizeOf(); + } + } + + @Override + public SetState createSingleState() { + return new SingleSetState(); + } + + + @Override + public SetState createGroupedState() { + return new GroupedSetState(); + } +} diff --git a/src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateLong.java b/src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateLong.java new file mode 100644 index 0000000..41cbd7c --- /dev/null +++ b/src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateLong.java @@ -0,0 +1,25 @@ +package com.github.archongum.trino.udf.aggregate.state; + +import java.util.Set; +import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.AccumulatorStateMetadata; + + +/** + * @author Archon 2021年10月21日 + */ +@AccumulatorStateMetadata(stateSerializerClass = SetStateLongSerializer.class, stateFactoryClass = SetStateLongFactory.class) +public interface SetStateLong extends AccumulatorState { + + /** + * get set + * @return set + */ + Set getSet(); + + /** + * set set + * @param value set + */ + void 
setSet(Set value); +} diff --git a/src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateLongFactory.java b/src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateLongFactory.java new file mode 100644 index 0000000..22015d2 --- /dev/null +++ b/src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateLongFactory.java @@ -0,0 +1,80 @@ +package com.github.archongum.trino.udf.aggregate.state; + +import java.util.HashSet; +import java.util.Set; +import io.trino.array.ObjectBigArray; +import io.trino.spi.function.AccumulatorStateFactory; +import io.trino.spi.function.GroupedAccumulatorState; + + +/** + * @author Archon 8/30/19 + * @since + */ +public class SetStateLongFactory implements AccumulatorStateFactory { + + public static final class SingleSetState implements SetStateLong { + + private Set set = new HashSet<>(); + + @Override + public Set getSet() { + return set; + } + + @Override + public void setSet(Set value) { + this.set = value; + } + + @Override + public long getEstimatedSize() { + return set.size(); + } + } + + public static class GroupedSetState implements GroupedAccumulatorState, SetStateLong { + + private final ObjectBigArray> container = new ObjectBigArray<>(); + private long groupId; + + @Override + public Set getSet() { + return container.get(groupId); + } + + @Override + public void setSet(Set value) { + container.set(groupId, value); + } + + @Override + public void setGroupId(long groupId) { + this.groupId = groupId; + if (this.getSet() == null) { + this.setSet(new HashSet<>()); + } + } + + @Override + public void ensureCapacity(long size) { + container.ensureCapacity(size); + } + + @Override + public long getEstimatedSize() { + return container.sizeOf(); + } + } + + @Override + public SetStateLong createSingleState() { + return new SingleSetState(); + } + + + @Override + public SetStateLong createGroupedState() { + return new GroupedSetState(); + } +} diff --git 
a/src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateLongSerializer.java b/src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateLongSerializer.java new file mode 100644 index 0000000..9199405 --- /dev/null +++ b/src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateLongSerializer.java @@ -0,0 +1,50 @@ +package com.github.archongum.trino.udf.aggregate.state; + +import java.io.IOException; +import java.util.HashSet; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.airlift.slice.Slices; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AccumulatorStateSerializer; +import io.trino.spi.type.Type; +import static io.trino.spi.type.VarcharType.VARCHAR; + + +/** + * @author Archon 8/30/19 + * @since + */ +public class SetStateLongSerializer implements AccumulatorStateSerializer +{ + + private final ObjectMapper mapper = new ObjectMapper(); + + @Override + public Type getSerializedType() { + return VARCHAR; + } + + @Override + public void serialize(SetStateLong state, BlockBuilder out) { + try { + String jsonResult = mapper.writeValueAsString(state.getSet()); + VARCHAR.writeSlice(out, Slices.utf8Slice(jsonResult)); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + @Override + public void deserialize(Block block, int index, SetStateLong state) { + try { + TypeReference> typeRef = new TypeReference<>() {}; + state.setSet(mapper.readValue(VARCHAR.getSlice(block, index).toStringUtf8(), typeRef)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} + diff --git a/src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateSerializer.java b/src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateSerializer.java new file mode 100644 index 0000000..71e96dd --- /dev/null +++ 
b/src/main/java/com/github/archongum/trino/udf/aggregate/state/SetStateSerializer.java @@ -0,0 +1,50 @@ +package com.github.archongum.trino.udf.aggregate.state; + +import java.io.IOException; +import java.util.HashSet; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.airlift.slice.Slices; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AccumulatorStateSerializer; +import io.trino.spi.type.Type; +import static io.trino.spi.type.VarcharType.VARCHAR; + + +/** + * @author Archon 8/30/19 + * @since + */ +public class SetStateSerializer implements AccumulatorStateSerializer +{ + + private final ObjectMapper mapper = new ObjectMapper(); + + @Override + public Type getSerializedType() { + return VARCHAR; + } + + @Override + public void serialize(SetState state, BlockBuilder out) { + try { + String jsonResult = mapper.writeValueAsString(state.getSet()); + VARCHAR.writeSlice(out, Slices.utf8Slice(jsonResult)); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + @Override + public void deserialize(Block block, int index, SetState state) { + try { + TypeReference> typeRef = new TypeReference<>() {}; + state.setSet(mapper.readValue(VARCHAR.getSlice(block, index).toStringUtf8(), typeRef)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} + diff --git a/src/main/java/com/github/archongum/trino/udf/scalar/ArrayMaxCountElementFunction.java b/src/main/java/com/github/archongum/trino/udf/scalar/ArrayMaxCountElementFunction.java new file mode 100644 index 0000000..5ded870 --- /dev/null +++ b/src/main/java/com/github/archongum/trino/udf/scalar/ArrayMaxCountElementFunction.java @@ -0,0 +1,103 @@ +package com.github.archongum.trino.udf.scalar; + +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; +import 
java.util.Optional; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.spi.block.Block; +import io.trino.spi.function.Description; +import io.trino.spi.function.ScalarFunction; +import io.trino.spi.function.SqlNullable; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; +import io.trino.spi.type.Type; + + +/** + * @author Archon 2019/8/30 + */ +@ScalarFunction("array_max_count_element") +@Description("Get maximum count element of array (null is not counting; if has multiple return one of them)") +public final class ArrayMaxCountElementFunction { + private ArrayMaxCountElementFunction() {} + + @TypeParameter("T") + @SqlType("T") + @SqlNullable + public static Slice sliceArrayMaxCountElement( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block block + ) { + if (block.getPositionCount() == 0) { + return null; + } + + Map map = new HashMap<>(16); + + for (int i = 0; i < block.getPositionCount(); i++) { + if (block.isNull(i)) { + continue; + } + Slice curElement = elementType.getSlice(block, i); + Long c = map.get(curElement); + if (c == null) { + map.put(curElement, 1L); + } else { + map.put(curElement, c+1); + } + } + + if (map.isEmpty()) { + return null; + } + + Optional> max = map.entrySet().stream().max(Entry.comparingByValue()); + + if (max.isPresent()) { + return max.get().getKey(); + } + + return Slices.EMPTY_SLICE; + } + + @TypeParameter("T") + @SqlType("T") + @SqlNullable + public static Long longArrayMaxCountElement( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block block + ) { + if (block.getPositionCount() == 0) { + return null; + } + + Map map = new HashMap<>(16); + + for (int i = 0; i < block.getPositionCount(); i++) { + if (block.isNull(i)) { + continue; + } + Long curElement = elementType.getLong(block, i); + Long c = map.get(curElement); + if (c == null) { + map.put(curElement, 1L); + } else { + map.put(curElement, c+1); + } + } + + if (map.isEmpty()) { 
+ return null; + } + + Optional> max = map.entrySet().stream().max(Entry.comparingByValue()); + + if (max.isPresent()) { + return max.get().getKey(); + } + + return 0L; + } +} diff --git a/src/main/java/com/github/archongum/trino/udf/scalar/CommonFunctions.java b/src/main/java/com/github/archongum/trino/udf/scalar/CommonFunctions.java new file mode 100644 index 0000000..da06e3a --- /dev/null +++ b/src/main/java/com/github/archongum/trino/udf/scalar/CommonFunctions.java @@ -0,0 +1,25 @@ +package com.github.archongum.trino.udf.scalar; + +import java.util.Random; +import io.airlift.slice.Slice; +import io.trino.spi.function.Description; +import io.trino.spi.function.ScalarFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.type.StandardTypes; + + +/** + * Random with string type seed + * + * @author Archon 2018/9/20 + * @since + */ +public class CommonFunctions { + + @Description("rand(String seed)") + @ScalarFunction("rand") + @SqlType(StandardTypes.DOUBLE) + public static double randomWithSeed(@SqlType(StandardTypes.VARCHAR) Slice seed) { + return new Random(seed.toStringUtf8().hashCode()).nextDouble(); + } +} diff --git a/src/main/java/com/github/archongum/trino/udf/scalar/DateTimeFunctions.java b/src/main/java/com/github/archongum/trino/udf/scalar/DateTimeFunctions.java new file mode 100644 index 0000000..71a2cae --- /dev/null +++ b/src/main/java/com/github/archongum/trino/udf/scalar/DateTimeFunctions.java @@ -0,0 +1,66 @@ +package com.github.archongum.trino.udf.scalar; + +import java.time.LocalDate; +import java.time.LocalDateTime; +import io.airlift.slice.Slice; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.Description; +import io.trino.spi.function.ScalarFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.type.DateTimeEncoding; +import io.trino.spi.type.StandardTypes; +import static com.github.archongum.trino.udf.scalar.DateTimeUtils.OFFSET_MILLIS; +import static 
java.util.concurrent.TimeUnit.DAYS; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + + +/** + * Date Time Functions. + * + * @author Archon 2018/9/20 + * @since + */ +public class DateTimeFunctions { + + @Description("yesterday") + @ScalarFunction("yesterday") + @SqlType(StandardTypes.DATE) + public static long yesterday() { + return MILLISECONDS.toDays(System.currentTimeMillis()) - 1; + } + + @Description("first day of month") + @ScalarFunction("first_day") + @SqlType(StandardTypes.DATE) + public static long firstDay(@SqlType(StandardTypes.DATE) long days) { + return DateTimeUtils.firstDayOfMonth((int) days).toEpochDay(); + } + + @Description("last day of month") + @ScalarFunction("last_day") + @SqlType(StandardTypes.DATE) + public static long lastDay(@SqlType(StandardTypes.DATE) long days) { + return DateTimeUtils.lastDayOfMonth((int) days).toEpochDay(); + } + + @Description("last second of the date") + @ScalarFunction("last_second") + @SqlType("timestamp(3) with time zone") + public static long lastSecond(ConnectorSession session, @SqlType(StandardTypes.DATE) long days) { + return DateTimeEncoding.packDateTimeWithZone(DateTimeUtils.toMillis(LocalDateTime.of(DateTimeUtils.toLocalDate((int) days), DateTimeUtils.LAST_SECOND)), session.getTimeZoneKey()); + } + + @Description("yesterday last second") + @ScalarFunction("yesterday_last_second") + @SqlType("timestamp(3) with time zone") + public static long yesterdayLastSecond(ConnectorSession session) { + return DateTimeEncoding.packDateTimeWithZone(DAYS.toMillis(MILLISECONDS.toDays(System.currentTimeMillis())) - OFFSET_MILLIS - 1, session.getTimeZoneKey()); + } + + @Description("to timestamp") + @ScalarFunction("to_datetime") + @SqlType("timestamp(3) with time zone") + public static long toDatetime(ConnectorSession session, @SqlType(StandardTypes.DATE) long days, @SqlType(StandardTypes.VARCHAR) Slice time) { + return 
/**
 * Conversions between epoch millis / epoch days and {@code java.time} types.
 *
 * <p>LocalDateTime &lt;-&gt; millis conversions use the JVM default zone's raw
 * offset captured at class-load time ({@link #ZONE_OFFSET}); LocalDate &lt;-&gt;
 * millis conversions are plain UTC epoch-day arithmetic.
 */
public class DateTimeUtils {

    /** Raw offset of the JVM default time zone, in millis (no DST adjustment). */
    public static final int OFFSET_MILLIS = Calendar.getInstance().getTimeZone().getRawOffset();

    public static final int OFFSET_SECOND = OFFSET_MILLIS / 1000;

    public static final ZoneOffset ZONE_OFFSET = ZoneOffset.ofTotalSeconds(OFFSET_SECOND);

    /** 23:59:59.999999999 -- truncates to 23:59:59.999 at millisecond precision. */
    public static final LocalTime LAST_SECOND = LocalTime.of(23, 59, 59, 999999999);

    // ----------------------------- base ------------------------ //

    /** Epoch millis of the given local date-time, interpreted at {@link #ZONE_OFFSET}. */
    public static long toMillis(LocalDateTime dateTime) {
        return dateTime.toEpochSecond(ZONE_OFFSET) * 1000 + dateTime.getNano() / 1000_000;
    }

    /**
     * Epoch millis at UTC midnight of the given date.
     * NOTE(review): unlike {@link #toMillis(LocalDateTime)} this ignores
     * ZONE_OFFSET -- confirm the asymmetry is intended.
     */
    public static long toMillis(LocalDate date) {
        return date.toEpochDay() * 86400000;
    }

    public static long toSeconds(long millis) {
        return MILLISECONDS.toSeconds(millis);
    }

    public static long toMinutes(long millis) {
        return MILLISECONDS.toMinutes(millis);
    }

    public static long toHours(long millis) {
        return MILLISECONDS.toHours(millis);
    }

    public static int toDays(long millis) {
        return (int) MILLISECONDS.toDays(millis);
    }

    /** Local date-time at {@link #ZONE_OFFSET}; sub-second precision is dropped. */
    public static LocalDateTime toLocalDateTime(long millis) {
        return LocalDateTime.ofEpochSecond(toSeconds(millis), 0, ZONE_OFFSET);
    }

    /** Combines a date with an ISO-8601 local-time string such as "23:59:59.999". */
    public static LocalDateTime toLocalDateTime(LocalDate date, String time) {
        return LocalDateTime.of(date, LocalTime.parse(time));
    }

    public static LocalDate toLocalDate(int days) {
        return LocalDate.ofEpochDay(days);
    }

    public static LocalDate toLocalDate(long millis) {
        return toLocalDate(toDays(millis));
    }

    // --------------------- first day of month ---------------------------//

    public static LocalDate firstDayOfMonth(LocalDate date) {
        return date.withDayOfMonth(1);
    }

    public static LocalDate firstDayOfMonth(LocalDateTime dateTime) {
        return firstDayOfMonth(dateTime.toLocalDate());
    }

    public static LocalDate firstDayOfMonth(long millis) {
        return firstDayOfMonth(toLocalDate(millis));
    }

    public static LocalDate firstDayOfMonth(int days) {
        return firstDayOfMonth(toLocalDate(days));
    }

    // --------------------- last day of month ---------------------------//

    /** Leap-year-aware last day of the date's month. */
    public static LocalDate lastDayOfMonth(LocalDate date) {
        // lengthOfMonth() replaces the original month+1 / minusDays arithmetic;
        // identical results, including December and February in leap years.
        return date.withDayOfMonth(date.lengthOfMonth());
    }

    public static LocalDate lastDayOfMonth(LocalDateTime dateTime) {
        return lastDayOfMonth(dateTime.toLocalDate());
    }

    public static LocalDate lastDayOfMonth(long millis) {
        return lastDayOfMonth(toLocalDate(millis));
    }

    public static LocalDate lastDayOfMonth(int days) {
        return lastDayOfMonth(toLocalDate(days));
    }
}
0000000..9accbdc --- /dev/null +++ b/src/test/com/github/archongum/trino/udf/aggregate/TestAggreateFunctions.java @@ -0,0 +1,39 @@ +package com.github.archongum.trino.udf.aggregate; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import org.junit.jupiter.api.Test; + + +/** + * @author Archon 2019/8/29 + * @since + */ +class TestAggreateFunctions { + + @Test + void testTopN() { + List list = Arrays.asList( + "a", "a", "a", + "b", "b", + "c", + "a", "a" + ); + + Map map = new HashMap<>(); + + for (String curElement : list) { + Long c = map.get(curElement); + if (c == null) { + map.put(curElement, 1L); + } else { + map.put(curElement, c + 1); + } + } + + System.out.println(new HashMap().entrySet().stream().max(Entry.comparingByValue()).get()); + } +} diff --git a/src/test/com/github/archongum/trino/udf/scalar/TestDateTimeFunctions.java b/src/test/com/github/archongum/trino/udf/scalar/TestDateTimeFunctions.java new file mode 100644 index 0000000..6d1dbba --- /dev/null +++ b/src/test/com/github/archongum/trino/udf/scalar/TestDateTimeFunctions.java @@ -0,0 +1,39 @@ +package com.github.archongum.trino.udf.scalar; + +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.ZoneOffset; +import org.junit.jupiter.api.Test; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + + +/** + * @author Archon 2018/9/20 + * @since + */ +class TestDateTimeFunctions { + + @Test + void testLastDay() { + long millis = System.currentTimeMillis(); + System.out.println(DateTimeFunctions.lastDay(DateTimeUtils.toDays(millis))); + System.out.println(MILLISECONDS.toDays(millis)); + System.out.println(LocalTime.parse("23:59:59.999")); + LocalDateTime ts = LocalDateTime.of(LocalDate.now(), LocalTime.parse("23:59:59.999")); + System.out.println(ts.toEpochSecond(ZoneOffset.UTC)*1000 + ts.getNano() / 1000_000); + } + + @Test + void testToDatetime() { + long 
millis = System.currentTimeMillis(); + long d1 = DateTimeUtils.toLocalDateTime(millis).toLocalDate().toEpochDay(); + System.out.println(d1); + System.out.println(millis/86400000); + } + + @Test + void testYesterday() { + System.out.println(DateTimeFunctions.yesterday()); + } +} diff --git a/unit_test.sql b/unit_test.sql new file mode 100644 index 0000000..483bdcd --- /dev/null +++ b/unit_test.sql @@ -0,0 +1,47 @@ +-- [Unit Test] Scalar Functions +with a as ( + select date '2022-05-16' as d +) +-- Expected Output: +-- 0.47325298871196086 +-- 1 +-- 2022-05-01 +-- 2022-05-31 +-- 2022-05-15 +-- 2022-05-15 23:59:59.999 +08:00 +-- 2022-05-16 23:59:59.999 +08:00 +-- 2022-05-16 23:59:59.999 +08:00 +select + rand('123') as rand, + array_max_count_element(array['1', '2', '1']) as array_max_count_element, + first_day(d) as first_day, + last_day(d) as last_day, + yesterday() as yesterday, + yesterday_last_second() as yesterday_last_second, + last_second(d) as last_second, + to_datetime(d, '23:59:59.999999') as to_datetime +from a; + + +-- [Unit Test] Aggregate Functions +with a as ( + select 'guangzhou' as city, 'tianhe' as district, array[1,2,3] as item_ids + union all + select 'guangzhou' as city, 'baiyun' as district, array[2,3,4] as item_ids + union all + select 'guangzhou' as city, 'baiyun' as district, null as item_ids + union all + select 'shenzhen' as city, 'futian' as district, array[1,2] as item_ids + union all + select 'shenzhen' as city, 'nanshan' as district, array[1,3,3] as item_ids +) +-- Expected Output: +-- guangzhou baiyun 4 +-- shenzhen futian 3 +select + city, + max_count_element(district), + array_agg_distinct_integer(item_ids) as sku_cnt +from a +group by city +order by city;