diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index 9d4cf239aadf..5bd9fa31e1d6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -3315,14 +3315,24 @@ protected Scope visitUpdate(Update update, Optional scope) ImmutableList.Builder analysesBuilder = ImmutableList.builder(); ImmutableList.Builder expressionTypesBuilder = ImmutableList.builder(); + ImmutableMap.Builder> sourceColumnsByColumnNameBuilder = ImmutableMap.builder(); for (UpdateAssignment assignment : update.getAssignments()) { + String targetColumnName = assignment.getName().getValue(); Expression expression = assignment.getValue(); - ExpressionAnalysis analysis = analyzeExpression(expression, tableScope); - analysesBuilder.add(analysis); - expressionTypesBuilder.add(analysis.getType(expression)); + ExpressionAnalysis expressionAnalysis = analyzeExpression(expression, tableScope); + analysesBuilder.add(expressionAnalysis); + expressionTypesBuilder.add(expressionAnalysis.getType(expression)); + + Set sourceColumns = expressionAnalysis.getSubqueries().stream() + .map(query -> analyze(query.getNode(), tableScope)) + .flatMap(subqueryScope -> subqueryScope.getRelationType().getVisibleFields().stream()) + .flatMap(field -> analysis.getSourceColumns(field).stream()) + .collect(toImmutableSet()); + sourceColumnsByColumnNameBuilder.put(targetColumnName, sourceColumns); } List analyses = analysesBuilder.build(); List expressionTypes = expressionTypesBuilder.build(); + Map> sourceColumnsByColumnName = sourceColumnsByColumnNameBuilder.buildOrThrow(); List tableTypes = update.getAssignments().stream() .map(assignment -> requireNonNull(columns.get(assignment.getName().getValue()))) @@ -3353,7 +3363,9 @@ protected Scope visitUpdate(Update update, Optional scope) tableName, Optional.of(table), Optional.of(updatedColumnSchemas.stream() - .map(column -> new OutputColumn(new Column(column.getName(), column.getType().toString()), ImmutableSet.of())) + .map(column -> new OutputColumn( + new Column(column.getName(), column.getType().toString()), + sourceColumnsByColumnName.getOrDefault(column.getName(), ImmutableSet.of()))) .collect(toImmutableList()))); createMergeAnalysis(table, handle, tableSchema, tableScope, tableScope, ImmutableList.of(updatedColumnHandles)); diff --git a/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java b/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java index 7f659e40a094..35523dc1ae99 100644 --- a/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java +++ b/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java @@ -1077,6 +1077,65 @@ public void testOutputColumnsForUpdatingSingleColumn() .containsExactly(new OutputColumnMetadata("test_varchar", VARCHAR_TYPE, ImmutableSet.of())); } + @Test + public void testOutputColumnsForUpdatingColumnWithSelectQuery() + throws Exception + { + QueryEvents queryEvents = runQueryAndWaitForEvents("UPDATE mock.default.table_for_output SET test_varchar = (SELECT name from nation LIMIT 1)").getQueryEvents(); + QueryCompletedEvent event = queryEvents.getQueryCompletedEvent(); + assertThat(event.getIoMetadata().getOutput().get().getColumns().get()) + .containsExactly(new OutputColumnMetadata("test_varchar", VARCHAR_TYPE, ImmutableSet.of(new ColumnDetail("tpch", "tiny", "nation", "name")))); + } + + @Test + public void testOutputColumnsForUpdatingColumnWithSelectQueryWithAliasedField() + throws Exception + { + QueryEvents queryEvents = runQueryAndWaitForEvents("UPDATE mock.default.table_for_output SET test_varchar = (SELECT name AS aliased_name from nation LIMIT 1)").getQueryEvents(); + QueryCompletedEvent event = queryEvents.getQueryCompletedEvent(); + assertThat(event.getIoMetadata().getOutput().get().getColumns().get()) + .containsExactly(new OutputColumnMetadata("test_varchar", VARCHAR_TYPE, ImmutableSet.of(new ColumnDetail("tpch", "tiny", "nation", "name")))); + } + + @Test + public void testOutputColumnsForUpdatingColumnsWithSelectQueries() + throws Exception + { + QueryEvents queryEvents = runQueryAndWaitForEvents(""" + UPDATE mock.default.table_for_output SET test_varchar = (SELECT name AS aliased_name from nation LIMIT 1), test_bigint = (SELECT nationkey FROM nation LIMIT 1) + """).getQueryEvents(); + QueryCompletedEvent event = queryEvents.getQueryCompletedEvent(); + assertThat(event.getIoMetadata().getOutput().get().getColumns().get()) + .containsExactlyInAnyOrder( + new OutputColumnMetadata("test_varchar", VARCHAR_TYPE, ImmutableSet.of(new ColumnDetail("tpch", "tiny", "nation", "name"))), + new OutputColumnMetadata("test_bigint", BIGINT_TYPE, ImmutableSet.of(new ColumnDetail("tpch", "tiny", "nation", "nationkey")))); + } + + @Test + public void testOutputColumnsForUpdatingColumnsWithSelectQueryAndRawValue() + throws Exception + { + QueryEvents queryEvents = runQueryAndWaitForEvents(""" + UPDATE mock.default.table_for_output SET test_varchar = (SELECT name AS aliased_name from nation LIMIT 1), test_bigint = 1 + """).getQueryEvents(); + QueryCompletedEvent event = queryEvents.getQueryCompletedEvent(); + assertThat(event.getIoMetadata().getOutput().get().getColumns().get()) + .containsExactlyInAnyOrder( + new OutputColumnMetadata("test_varchar", VARCHAR_TYPE, ImmutableSet.of(new ColumnDetail("tpch", "tiny", "nation", "name"))), + new OutputColumnMetadata("test_bigint", BIGINT_TYPE, ImmutableSet.of())); + } + + @Test + public void testOutputColumnsForUpdatingColumnWithSelectQueryAndWhereClauseWithOuterColumn() + throws Exception + { + QueryEvents queryEvents = runQueryAndWaitForEvents(""" + UPDATE mock.default.table_for_output SET test_varchar = (SELECT name from nation WHERE test_bigint = nationkey)""").getQueryEvents(); + QueryCompletedEvent event = queryEvents.getQueryCompletedEvent(); + assertThat(event.getIoMetadata().getOutput().get().getColumns().get()) + .containsExactly(new OutputColumnMetadata("test_varchar", VARCHAR_TYPE, ImmutableSet.of(new ColumnDetail("tpch", "tiny", "nation", "name")))); + } + @Test public void testCreateTable() throws Exception