Skip to content

Commit

Permalink
Rename AggregateFunction#inputs to arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
assaf2 authored and losipiuk committed Feb 23, 2022
1 parent 55d9ac2 commit 083035d
Show file tree
Hide file tree
Showing 32 changed files with 192 additions and 192 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,28 @@ public class AggregateFunction
{
private final String functionName;
private final Type outputType;
private final List<ConnectorExpression> inputs;
private final List<ConnectorExpression> arguments;
private final List<SortItem> sortItems;
private final boolean isDistinct;
private final Optional<ConnectorExpression> filter;

public AggregateFunction(
String aggregateFunctionName,
Type outputType,
List<ConnectorExpression> inputs,
List<ConnectorExpression> arguments,
List<SortItem> sortItems,
boolean isDistinct,
Optional<ConnectorExpression> filter)
{
if (isDistinct && inputs.isEmpty()) {
throw new IllegalArgumentException("DISTINCT requires inputs");
if (isDistinct && arguments.isEmpty()) {
throw new IllegalArgumentException("DISTINCT requires arguments");
}

this.functionName = requireNonNull(aggregateFunctionName, "aggregateFunctionName is null");
this.outputType = requireNonNull(outputType, "outputType is null");
requireNonNull(inputs, "inputs is null");
requireNonNull(arguments, "arguments is null");
requireNonNull(sortItems, "sortItems is null");
this.inputs = List.copyOf(inputs);
this.arguments = List.copyOf(arguments);
this.sortItems = List.copyOf(sortItems);
this.isDistinct = isDistinct;
this.filter = requireNonNull(filter, "filter is null");
Expand All @@ -59,9 +59,9 @@ public String getFunctionName()
return functionName;
}

public List<ConnectorExpression> getInputs()
public List<ConnectorExpression> getArguments()
{
return inputs;
return arguments;
}

public Type getOutputType()
Expand Down Expand Up @@ -89,7 +89,7 @@ public String toString()
{
return new StringJoiner(", ", AggregateFunction.class.getSimpleName() + "[", "]")
.add("aggregationName='" + functionName + "'")
.add("inputs=" + inputs)
.add("arguments=" + arguments)
.add("outputType=" + outputType)
.add("sortOrder=" + sortItems)
.add("isDistinct=" + isDistinct)
Expand All @@ -111,7 +111,7 @@ public boolean equals(Object o)
AggregateFunction that = (AggregateFunction) o;
return isDistinct == that.isDistinct &&
Objects.equals(functionName, that.functionName) &&
Objects.equals(inputs, that.inputs) &&
Objects.equals(arguments, that.arguments) &&
Objects.equals(outputType, that.outputType) &&
Objects.equals(sortItems, that.sortItems) &&
Objects.equals(filter, that.filter);
Expand All @@ -120,6 +120,6 @@ public boolean equals(Object o)
@Override
public int hashCode()
{
return Objects.hash(functionName, inputs, outputType, sortItems, isDistinct, filter);
return Objects.hash(functionName, arguments, outputType, sortItems, isDistinct, filter);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,19 @@ public static Pattern<AggregateFunction> basicAggregation()
return Property.property("outputType", AggregateFunction::getOutputType);
}

public static Property<AggregateFunction, ?, List<ConnectorExpression>> inputs()
public static Property<AggregateFunction, ?, List<ConnectorExpression>> arguments()
{
return Property.property("inputs", AggregateFunction::getInputs);
return Property.property("arguments", AggregateFunction::getArguments);
}

public static Property<AggregateFunction, ?, ConnectorExpression> singleInput()
public static Property<AggregateFunction, ?, ConnectorExpression> singleArgument()
{
return Property.optionalProperty("inputs", aggregateFunction -> {
List<ConnectorExpression> inputs = aggregateFunction.getInputs();
if (inputs.size() != 1) {
return Property.optionalProperty("arguments", aggregateFunction -> {
List<ConnectorExpression> arguments = aggregateFunction.getArguments();
if (arguments.size() != 1) {
return Optional.empty();
}
return Optional.of(inputs.get(0));
return Optional.of(arguments.get(0));
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.expressionType;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleArgument;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.variable;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DoubleType.DOUBLE;
Expand All @@ -45,29 +45,29 @@
public abstract class BaseImplementAvgBigint
implements AggregateFunctionRule<JdbcExpression>
{
private final Capture<Variable> input;
private final Capture<Variable> argument;

public BaseImplementAvgBigint()
{
this.input = newCapture();
this.argument = newCapture();
}

@Override
public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("avg"))
.with(singleInput().matching(
.with(singleArgument().matching(
variable()
.with(expressionType().matching(type -> type == BIGINT))
.capturedAs(this.input)));
.capturedAs(this.argument)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
{
Variable input = captures.get(this.input);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
Variable argument = captures.get(this.argument);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName());
verify(aggregateFunction.getOutputType() == DOUBLE);

String columnName = context.getIdentifierQuote().apply(columnHandle.getColumnName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.expressionType;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleArgument;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.variable;
import static java.lang.String.format;

Expand All @@ -40,24 +40,24 @@
public class ImplementAvgDecimal
implements AggregateFunctionRule<JdbcExpression>
{
private static final Capture<Variable> INPUT = newCapture();
private static final Capture<Variable> ARGUMENT = newCapture();

@Override
public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("avg"))
.with(singleInput().matching(
.with(singleArgument().matching(
variable()
.with(expressionType().matching(DecimalType.class::isInstance))
.capturedAs(INPUT)));
.capturedAs(ARGUMENT)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
{
Variable input = captures.get(INPUT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
Variable argument = captures.get(ARGUMENT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName());
DecimalType type = (DecimalType) columnHandle.getColumnType();
verify(aggregateFunction.getOutputType().equals(type));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.expressionType;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleArgument;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.variable;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.RealType.REAL;
Expand All @@ -41,24 +41,24 @@
public class ImplementAvgFloatingPoint
implements AggregateFunctionRule<JdbcExpression>
{
private static final Capture<Variable> INPUT = newCapture();
private static final Capture<Variable> ARGUMENT = newCapture();

@Override
public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("avg"))
.with(singleInput().matching(
.with(singleArgument().matching(
variable()
.with(expressionType().matching(type -> type == REAL || type == DOUBLE))
.capturedAs(INPUT)));
.capturedAs(ARGUMENT)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
{
Variable input = captures.get(INPUT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
Variable argument = captures.get(ARGUMENT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName());
verify(aggregateFunction.getOutputType() == columnHandle.getColumnType());

return Optional.of(new JdbcExpression(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.arguments;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.expressionTypes;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.inputs;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.variables;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.RealType.REAL;
Expand All @@ -39,27 +39,27 @@
public class ImplementCorr
implements AggregateFunctionRule<JdbcExpression>
{
private static final Capture<List<Variable>> INPUTS = newCapture();
private static final Capture<List<Variable>> ARGUMENTS = newCapture();

@Override
public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("corr"))
.with(inputs().matching(
.with(arguments().matching(
variables()
.matching(expressionTypes(REAL, REAL).or(expressionTypes(DOUBLE, DOUBLE)))
.capturedAs(INPUTS)));
.capturedAs(ARGUMENTS)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
{
List<Variable> inputs = captures.get(INPUTS);
verify(inputs.size() == 2);
List<Variable> arguments = captures.get(ARGUMENTS);
verify(arguments.size() == 2);

JdbcColumnHandle columnHandle1 = (JdbcColumnHandle) context.getAssignment(inputs.get(0).getName());
JdbcColumnHandle columnHandle2 = (JdbcColumnHandle) context.getAssignment(inputs.get(1).getName());
JdbcColumnHandle columnHandle1 = (JdbcColumnHandle) context.getAssignment(arguments.get(0).getName());
JdbcColumnHandle columnHandle2 = (JdbcColumnHandle) context.getAssignment(arguments.get(1).getName());
verify(aggregateFunction.getOutputType().equals(columnHandle1.getColumnType()));

return Optional.of(new JdbcExpression(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleArgument;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.variable;
import static io.trino.spi.type.BigintType.BIGINT;
import static java.lang.String.format;
Expand All @@ -43,7 +43,7 @@
public class ImplementCount
implements AggregateFunctionRule<JdbcExpression>
{
private static final Capture<Variable> INPUT = newCapture();
private static final Capture<Variable> ARGUMENT = newCapture();

private final JdbcTypeHandle bigintTypeHandle;

Expand All @@ -60,14 +60,14 @@ public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("count"))
.with(singleInput().matching(variable().capturedAs(INPUT)));
.with(singleArgument().matching(variable().capturedAs(ARGUMENT)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
{
Variable input = captures.get(INPUT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
Variable argument = captures.get(ARGUMENT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName());
verify(aggregateFunction.getOutputType() == BIGINT);

return Optional.of(new JdbcExpression(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.arguments;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.inputs;
import static io.trino.spi.type.BigintType.BIGINT;
import static java.util.Objects.requireNonNull;

Expand All @@ -53,7 +53,7 @@ public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("count"))
.with(inputs().equalTo(List.of()));
.with(arguments().equalTo(List.of()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.distinct;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.hasFilter;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleArgument;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.variable;
import static io.trino.spi.type.BigintType.BIGINT;
import static java.lang.String.format;
Expand All @@ -46,7 +46,7 @@
public class ImplementCountDistinct
implements AggregateFunctionRule<JdbcExpression>
{
private static final Capture<Variable> INPUT = newCapture();
private static final Capture<Variable> ARGUMENT = newCapture();

private final JdbcTypeHandle bigintTypeHandle;
private final boolean isRemoteCollationSensitive;
Expand All @@ -67,14 +67,14 @@ public Pattern<AggregateFunction> getPattern()
.with(distinct().equalTo(true))
.with(hasFilter().equalTo(false))
.with(functionName().equalTo("count"))
.with(singleInput().matching(variable().capturedAs(INPUT)));
.with(singleArgument().matching(variable().capturedAs(ARGUMENT)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
{
Variable input = captures.get(INPUT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
Variable argument = captures.get(ARGUMENT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName());
verify(aggregateFunction.getOutputType() == BIGINT);

boolean isCaseSensitiveType = columnHandle.getColumnType() instanceof CharType || columnHandle.getColumnType() instanceof VarcharType;
Expand Down
Loading

0 comments on commit 083035d

Please sign in to comment.