
Commit 6096fd1 (branch snowflake/cdk-td-changes)
gisripa committed Feb 20, 2024
1 parent eca8629
Showing 7 changed files with 126 additions and 96 deletions.
@@ -5,7 +5,7 @@ plugins {
airbyteJavaConnector {
cdkVersionRequired = '0.20.9'
features = ['db-destinations', 's3-destinations', 'typing-deduping']
useLocalCdk = false
useLocalCdk = true
}

java {
@@ -13,6 +13,7 @@
import io.airbyte.cdk.integrations.base.TypingAndDedupingFlag;
import io.airbyte.cdk.integrations.destination.NamingConventionTransformer;
import io.airbyte.cdk.integrations.destination.jdbc.AbstractJdbcDestination;
import io.airbyte.cdk.integrations.destination.jdbc.typing_deduping.JdbcDestinationHandler;
import io.airbyte.cdk.integrations.destination.jdbc.typing_deduping.JdbcSqlGenerator;
import io.airbyte.cdk.integrations.destination.staging.StagingConsumerFactory;
import io.airbyte.commons.json.Jsons;
@@ -129,6 +130,11 @@ protected JdbcSqlGenerator getSqlGenerator() {
throw new UnsupportedOperationException("Snowflake does not yet use the native JDBC DV2 interface");
}

@Override
protected JdbcDestinationHandler getDestinationHandler(String databaseName, JdbcDatabase database) {
throw new UnsupportedOperationException("Snowflake does not yet use the native JDBC DV2 interface");
}

@Override
public SerializedAirbyteMessageConsumer getSerializedMessageConsumer(final JsonNode config,
final ConfiguredAirbyteCatalog catalog,
@@ -158,7 +164,7 @@ public SerializedAirbyteMessageConsumer getSerializedMessageConsumer(final JsonN
final boolean disableTypeDedupe = config.has(DISABLE_TYPE_DEDUPE) && config.get(DISABLE_TYPE_DEDUPE).asBoolean(false);
final int defaultThreadCount = 8;
if (disableTypeDedupe) {
typerDeduper = new NoOpTyperDeduperWithV1V2Migrations<>(sqlGenerator, snowflakeDestinationHandler, parsedCatalog, migrator, v2TableMigrator,
typerDeduper = new NoOpTyperDeduperWithV1V2Migrations(sqlGenerator, snowflakeDestinationHandler, parsedCatalog, migrator, v2TableMigrator,
defaultThreadCount);
} else {
typerDeduper =

This file was deleted.

@@ -4,24 +4,43 @@

package io.airbyte.integrations.destination.snowflake.typing_deduping;

import static io.airbyte.cdk.integrations.base.JavaBaseConstants.COLUMN_NAME_AB_META;

import io.airbyte.cdk.db.jdbc.JdbcDatabase;
import io.airbyte.cdk.integrations.base.JavaBaseConstants;
import io.airbyte.cdk.integrations.destination.jdbc.ColumnDefinition;
import io.airbyte.cdk.integrations.destination.jdbc.TableDefinition;
import io.airbyte.cdk.integrations.destination.jdbc.typing_deduping.JdbcDestinationHandler;
import io.airbyte.integrations.base.destination.typing_deduping.AirbyteProtocolType;
import io.airbyte.integrations.base.destination.typing_deduping.AirbyteType;
import io.airbyte.integrations.base.destination.typing_deduping.Array;
import io.airbyte.integrations.base.destination.typing_deduping.ColumnId;
import io.airbyte.integrations.base.destination.typing_deduping.DestinationHandler;
import io.airbyte.integrations.base.destination.typing_deduping.DestinationInitialState;
import io.airbyte.integrations.base.destination.typing_deduping.InitialRawTableState;
import io.airbyte.integrations.base.destination.typing_deduping.Sql;
import io.airbyte.integrations.base.destination.typing_deduping.StreamConfig;
import io.airbyte.integrations.base.destination.typing_deduping.StreamId;
import io.airbyte.integrations.base.destination.typing_deduping.Struct;
import io.airbyte.integrations.base.destination.typing_deduping.Union;
import io.airbyte.integrations.base.destination.typing_deduping.UnsupportedOneOf;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.time.Instant;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import net.snowflake.client.jdbc.SnowflakeSQLException;
import org.apache.commons.text.StringSubstitutor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SnowflakeDestinationHandler implements DestinationHandler<SnowflakeTableDefinition> {
public class SnowflakeDestinationHandler extends JdbcDestinationHandler {

private static final Logger LOGGER = LoggerFactory.getLogger(SnowflakeDestinationHandler.class);
public static final String EXCEPTION_COMMON_PREFIX = "JavaScript execution error: Uncaught Execution of multiple statements failed on statement";
@@ -30,60 +49,58 @@ public class SnowflakeDestinationHandler implements DestinationHandler<Snowflake
private final JdbcDatabase database;

public SnowflakeDestinationHandler(final String databaseName, final JdbcDatabase database) {
super(databaseName, database);
this.databaseName = databaseName;
this.database = database;
}

@Override
public Optional<SnowflakeTableDefinition> findExistingTable(final StreamId id) throws SQLException {
public Optional<TableDefinition> findExistingTable(final StreamId id) throws SQLException {
// The obvious database.getMetaData().getColumns() solution doesn't work, because JDBC translates
// VARIANT as VARCHAR
final LinkedHashMap<String, SnowflakeColumnDefinition> columns = database.queryJsons(
"""
SELECT column_name, data_type, is_nullable
FROM information_schema.columns
WHERE table_catalog = ?
AND table_schema = ?
AND table_name = ?
ORDER BY ordinal_position;
""",
databaseName.toUpperCase(),
id.finalNamespace().toUpperCase(),
id.finalName().toUpperCase()).stream()
final LinkedHashMap<String, ColumnDefinition> columns = database.queryJsons(
"""
SELECT column_name, data_type, is_nullable
FROM information_schema.columns
WHERE table_catalog = ?
AND table_schema = ?
AND table_name = ?
ORDER BY ordinal_position;
""",
databaseName.toUpperCase(),
id.finalNamespace().toUpperCase(),
id.finalName().toUpperCase()).stream()
.collect(LinkedHashMap::new,
(map, row) -> map.put(
row.get("COLUMN_NAME").asText(),
new SnowflakeColumnDefinition(row.get("DATA_TYPE").asText(), fromSnowflakeBoolean(row.get("IS_NULLABLE").asText()))),
LinkedHashMap::putAll);
(map, row) -> map.put(
row.get("COLUMN_NAME").asText(),
new ColumnDefinition(
row.get("COLUMN_NAME").asText(),
row.get("DATA_TYPE").asText(),
0, //unused
fromIsNullableIsoString(row.get("IS_NULLABLE").asText()))),
LinkedHashMap::putAll);
if (columns.isEmpty()) {
return Optional.empty();
} else {
return Optional.of(new SnowflakeTableDefinition(columns));
return Optional.of(new TableDefinition(columns));
}
}
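The comment above findExistingTable explains why this handler queries information_schema.columns directly: DatabaseMetaData.getColumns() reports Snowflake VARIANT columns as VARCHAR, so the real data types have to be read from the catalog view. The following is a minimal, standalone sketch of that approach using plain JDBC; the class name, connection URL placeholders, and table identifiers are illustrative assumptions, not code from this commit.

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.LinkedHashMap;
import java.util.Map;

public class InformationSchemaColumnsSketch {

  // Reads column name -> Snowflake data type for one table, preserving ordinal order.
  static LinkedHashMap<String, String> readColumnTypes(final Connection conn,
                                                       final String database,
                                                       final String schema,
                                                       final String table) throws SQLException {
    final String sql = """
        SELECT column_name, data_type, is_nullable
        FROM information_schema.columns
        WHERE table_catalog = ?
          AND table_schema = ?
          AND table_name = ?
        ORDER BY ordinal_position
        """;
    final LinkedHashMap<String, String> columns = new LinkedHashMap<>();
    try (PreparedStatement stmt = conn.prepareStatement(sql)) {
      // Snowflake stores unquoted identifiers in upper case, hence toUpperCase().
      stmt.setString(1, database.toUpperCase());
      stmt.setString(2, schema.toUpperCase());
      stmt.setString(3, table.toUpperCase());
      try (ResultSet rs = stmt.executeQuery()) {
        while (rs.next()) {
          // DATA_TYPE here is the catalog type name (e.g. VARIANT), unlike
          // DatabaseMetaData.getColumns(), which maps VARIANT to VARCHAR.
          columns.put(rs.getString("COLUMN_NAME"), rs.getString("DATA_TYPE"));
        }
      }
    }
    return columns;
  }

  public static void main(final String[] args) throws SQLException {
    // Hypothetical connection URL, credentials, and identifiers; replace with real values.
    try (Connection conn = DriverManager.getConnection(
        "jdbc:snowflake://<account>.snowflakecomputing.com/", "USER", "PASSWORD")) {
      for (final Map.Entry<String, String> e
          : readColumnTypes(conn, "MY_DB", "MY_SCHEMA", "USERS_FINAL").entrySet()) {
        System.out.println(e.getKey() + " -> " + e.getValue());
      }
    }
  }
}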

@Override
public LinkedHashMap<String, SnowflakeTableDefinition> findExistingFinalTables(final List<StreamId> list) throws Exception {
return null;
}

@Override
public boolean isFinalTableEmpty(final StreamId id) throws SQLException {
private boolean isFinalTableEmpty(final StreamId id) throws SQLException {
final int rowCount = database.queryInt(
"""
SELECT row_count
FROM information_schema.tables
WHERE table_catalog = ?
AND table_schema = ?
AND table_name = ?
""",
SELECT row_count
FROM information_schema.tables
WHERE table_catalog = ?
AND table_schema = ?
AND table_name = ?
""",
databaseName.toUpperCase(),
id.finalNamespace().toUpperCase(),
id.finalName().toUpperCase());
return rowCount == 0;
}

@Override

public InitialRawTableState getInitialRawTableState(final StreamId id) throws Exception {
final ResultSet tables = database.getMetaData().getTables(
databaseName,
@@ -99,7 +116,7 @@ public InitialRawTableState getInitialRawTableState(final StreamId id) throws Ex
final Optional<String> minUnloadedTimestamp = Optional.ofNullable(database.queryStrings(
conn -> conn.createStatement().executeQuery(new StringSubstitutor(Map.of(
"raw_table", id.rawTableId(SnowflakeSqlGenerator.QUOTE))).replace(
"""
"""
SELECT to_varchar(
TIMESTAMPADD(NANOSECOND, -1, MIN("_airbyte_extracted_at")),
'YYYY-MM-DDTHH24:MI:SS.FF9TZH:TZM'
@@ -118,7 +135,7 @@ record -> record.getString("MIN_TIMESTAMP")).get(0));
final Optional<String> maxTimestamp = Optional.ofNullable(database.queryStrings(
conn -> conn.createStatement().executeQuery(new StringSubstitutor(Map.of(
"raw_table", id.rawTableId(SnowflakeSqlGenerator.QUOTE))).replace(
"""
"""
SELECT to_varchar(
MAX("_airbyte_extracted_at"),
'YYYY-MM-DDTHH24:MI:SS.FF9TZH:TZM'
@@ -158,12 +175,61 @@ public void execute(final Sql sql) throws Exception {
}
}

/**
* In snowflake information_schema tables, booleans return "YES" and "NO", which DataBind doesn't
* know how to use
*/
private boolean fromSnowflakeBoolean(final String input) {
return input.equalsIgnoreCase("yes");

private Set<String> getPks(final StreamConfig stream) {
return stream.primaryKey() != null ? stream.primaryKey().stream().map(ColumnId::name).collect(Collectors.toSet()) : Collections.emptySet();
}
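Both the deleted fromSnowflakeBoolean helper above and the fromIsNullableIsoString call in findExistingTable handle the same quirk: information_schema reports nullability as the ISO strings "YES" and "NO" rather than as booleans. A one-method sketch of that conversion, written here as an assumption about what the inherited CDK helper does rather than as code from this commit:

  // information_schema.columns.IS_NULLABLE is the ISO string "YES" or "NO".
  private static boolean isNullableFromIsoString(final String isNullable) {
    return "YES".equalsIgnoreCase(isNullable);
  }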

@Override
protected boolean isAirbyteMetaColumnMatch(TableDefinition existingTable) {
return existingTable.columns().containsKey(COLUMN_NAME_AB_META) &&
"VARIANT".equals(existingTable.columns().get(COLUMN_NAME_AB_META).type());
}

protected boolean existingSchemaMatchesStreamConfig(final StreamConfig stream, final TableDefinition existingTable) {
final Set<String> pks = getPks(stream);
// soft-resetting https://github.com/airbytehq/airbyte/pull/31082
@SuppressWarnings("deprecation") final boolean hasPksWithNonNullConstraint = existingTable.columns().entrySet().stream()
.anyMatch(c -> pks.contains(c.getKey()) && !c.getValue().isNullable());

return !hasPksWithNonNullConstraint
&& super.existingSchemaMatchesStreamConfig(stream, existingTable);

}

@Override
public List<DestinationInitialState> gatherInitialState(List<StreamConfig> streamConfigs) throws Exception {
return null;
}

@Override
protected String toJdbcTypeName(AirbyteType airbyteType) {
if (airbyteType instanceof final AirbyteProtocolType p) {
return toJdbcTypeName(p);
}

return switch (airbyteType.getTypeName()) {
case Struct.TYPE -> "OBJECT";
case Array.TYPE -> "ARRAY";
case UnsupportedOneOf.TYPE -> "VARIANT";
case Union.TYPE -> toJdbcTypeName(((Union) airbyteType).chooseType());
default -> throw new IllegalArgumentException("Unrecognized type: " + airbyteType.getTypeName());
};
}

private String toJdbcTypeName(final AirbyteProtocolType airbyteProtocolType) {
return switch (airbyteProtocolType) {
case STRING -> "TEXT";
case NUMBER -> "FLOAT";
case INTEGER -> "NUMBER";
case BOOLEAN -> "BOOLEAN";
case TIMESTAMP_WITH_TIMEZONE -> "TIMESTAMP_TZ";
case TIMESTAMP_WITHOUT_TIMEZONE -> "TIMESTAMP_NTZ";
// If you change this - also change the logic in extractAndCast
case TIME_WITH_TIMEZONE -> "TEXT";
case TIME_WITHOUT_TIMEZONE -> "TIME";
case DATE -> "DATE";
case UNKNOWN -> "VARIANT";
};
}
}
@@ -36,7 +36,7 @@
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringSubstitutor;

public class SnowflakeSqlGenerator implements SqlGenerator<SnowflakeTableDefinition> {
public class SnowflakeSqlGenerator implements SqlGenerator {

public static final String QUOTE = "\"";

@@ -134,36 +134,6 @@ public Sql createTable(final StreamConfig stream, final String suffix, final boo
"""));
}

@Override
public boolean existingSchemaMatchesStreamConfig(final StreamConfig stream, final SnowflakeTableDefinition existingTable)
throws TableNotMigratedException {
final Set<String> pks = getPks(stream);

// Check that the columns match, with special handling for the metadata columns.
final LinkedHashMap<String, String> intendedColumns = stream.columns().entrySet().stream()
.collect(LinkedHashMap::new,
(map, column) -> map.put(column.getKey().name(), toDialectType(column.getValue())),
LinkedHashMap::putAll);
final LinkedHashMap<String, String> actualColumns = existingTable.columns().entrySet().stream()
.filter(column -> JavaBaseConstants.V2_FINAL_TABLE_METADATA_COLUMNS.stream().map(String::toUpperCase)
.noneMatch(airbyteColumnName -> airbyteColumnName.equals(column.getKey())))
.collect(LinkedHashMap::new,
(map, column) -> map.put(column.getKey(), column.getValue().type()),
LinkedHashMap::putAll);
// soft-resetting https://github.com/airbytehq/airbyte/pull/31082
@SuppressWarnings("deprecation")
final boolean hasPksWithNonNullConstraint = existingTable.columns().entrySet().stream()
.anyMatch(c -> pks.contains(c.getKey()) && !c.getValue().isNullable());

final boolean sameColumns = actualColumns.equals(intendedColumns)
&& !hasPksWithNonNullConstraint
&& "TEXT".equals(existingTable.columns().get(JavaBaseConstants.COLUMN_NAME_AB_RAW_ID.toUpperCase()).type())
&& "TIMESTAMP_TZ".equals(existingTable.columns().get(JavaBaseConstants.COLUMN_NAME_AB_EXTRACTED_AT.toUpperCase()).type())
&& "VARIANT".equals(existingTable.columns().get(JavaBaseConstants.COLUMN_NAME_AB_META.toUpperCase()).type());

return sameColumns;
}

@Override
public Sql updateTable(final StreamConfig stream,
final String finalSuffix,
@@ -552,8 +522,4 @@ public static String escapeSingleQuotedString(final String str) {
.replace("'", "\\'");
}

private static Set<String> getPks(final StreamConfig stream) {
return stream.primaryKey() != null ? stream.primaryKey().stream().map(ColumnId::name).collect(Collectors.toSet()) : Collections.emptySet();
}

}
@@ -107,7 +107,7 @@ protected void globalTeardown() throws Exception {
}

@Override
protected SqlGenerator<?> getSqlGenerator() {
protected SqlGenerator getSqlGenerator() {
return new SnowflakeSqlGenerator();
}

@@ -22,6 +22,7 @@
import io.airbyte.commons.io.IOs;
import io.airbyte.commons.json.Jsons;
import io.airbyte.integrations.base.destination.typing_deduping.BaseSqlGeneratorIntegrationTest;
import io.airbyte.integrations.base.destination.typing_deduping.DestinationInitialState;
import io.airbyte.integrations.base.destination.typing_deduping.Sql;
import io.airbyte.integrations.base.destination.typing_deduping.StreamId;
import io.airbyte.integrations.destination.snowflake.OssCloudEnvVarConsts;
@@ -44,7 +45,7 @@
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

public class SnowflakeSqlGeneratorIntegrationTest extends BaseSqlGeneratorIntegrationTest<SnowflakeTableDefinition> {
public class SnowflakeSqlGeneratorIntegrationTest extends BaseSqlGeneratorIntegrationTest {

private static String databaseName;
private static JdbcDatabase database;
@@ -411,8 +412,9 @@ public void ensurePKsAreIndexedUnique() throws Exception {

// should be OK with new tables
destinationHandler.execute(createTable);
final Optional<SnowflakeTableDefinition> existingTableA = destinationHandler.findExistingTable(streamId);
assertTrue(generator.existingSchemaMatchesStreamConfig(incrementalDedupStream, existingTableA.get()));
List<DestinationInitialState> initialStates = destinationHandler.gatherInitialState(List.of(incrementalDedupStream));
assertEquals(1, initialStates.size());
assertFalse(initialStates.get(0).isSchemaMismatch());
destinationHandler.execute(Sql.of("DROP TABLE " + streamId.finalTableId("")));

// Hack the create query to add NOT NULLs to emulate the old behavior
@@ -424,8 +426,9 @@
.collect(joining("\r\n")))
.toList()).toList();
destinationHandler.execute(new Sql(createTableModified));
final Optional<SnowflakeTableDefinition> existingTableB = destinationHandler.findExistingTable(streamId);
assertFalse(generator.existingSchemaMatchesStreamConfig(incrementalDedupStream, existingTableB.get()));
initialStates = destinationHandler.gatherInitialState(List.of(incrementalDedupStream));
assertEquals(1, initialStates.size());
assertTrue(initialStates.get(0).isSchemaMismatch());
}

}
