diff --git a/airbyte-integrations/connectors/destination-snowflake/gradle.properties b/airbyte-integrations/connectors/destination-snowflake/gradle.properties index ade3d02d02d..21f0eda92a4 100644 --- a/airbyte-integrations/connectors/destination-snowflake/gradle.properties +++ b/airbyte-integrations/connectors/destination-snowflake/gradle.properties @@ -1,3 +1,3 @@ testExecutionConcurrency=-1 -cdkVersion=0.1.82 +cdkVersion=0.1.91 JunitMethodExecutionTimeout=10m diff --git a/airbyte-integrations/connectors/destination-snowflake/metadata.yaml b/airbyte-integrations/connectors/destination-snowflake/metadata.yaml index f7ac3c12394..79525a8aa12 100644 --- a/airbyte-integrations/connectors/destination-snowflake/metadata.yaml +++ b/airbyte-integrations/connectors/destination-snowflake/metadata.yaml @@ -6,7 +6,7 @@ data: connectorSubtype: database connectorType: destination definitionId: 424892c4-daac-4491-b35d-c6688ba547ba - dockerImageTag: 4.0.31 + dockerImageTag: 4.0.32-rc.1 dockerRepository: airbyte/destination-snowflake documentationUrl: https://docs.airbyte.com/integrations/destinations/snowflake githubIssueLabel: destination-snowflake @@ -31,6 +31,8 @@ data: enabled: true releaseStage: generally_available releases: + rolloutConfiguration: + enableProgressiveRollout: true breakingChanges: 2.0.0: message: Remove GCS/S3 loading method support. diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeBeanFactory.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeBeanFactory.kt index a7f8b9fdb31..777e0718f50 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeBeanFactory.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeBeanFactory.kt @@ -11,15 +11,18 @@ import io.airbyte.cdk.load.check.CheckOperationV2 import io.airbyte.cdk.load.check.DestinationCheckerV2 import io.airbyte.cdk.load.config.DataChannelMedium import io.airbyte.cdk.load.dataflow.config.AggregatePublishingConfig -import io.airbyte.cdk.load.orchestration.db.DefaultTempTableNameGenerator -import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator +import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator +import io.airbyte.cdk.load.table.TempTableNameGenerator import io.airbyte.cdk.output.OutputConsumer import io.airbyte.integrations.destination.snowflake.cdk.SnowflakeMigratingConfigurationSpecificationSupplier -import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName +import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.spec.KeyPairAuthConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfigurationFactory import io.airbyte.integrations.destination.snowflake.spec.UsernamePasswordAuthConfiguration +import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRawRecordFormatter +import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter +import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter import io.micronaut.context.annotation.Factory import io.micronaut.context.annotation.Primary import 
io.micronaut.context.annotation.Requires @@ -204,6 +207,17 @@ class SnowflakeBeanFactory { outputConsumer: OutputConsumer, ) = CheckOperationV2(destinationChecker, outputConsumer) + @Singleton + fun snowflakeRecordFormatter( + snowflakeConfiguration: SnowflakeConfiguration + ): SnowflakeRecordFormatter { + return if (snowflakeConfiguration.legacyRawTablesOnly) { + SnowflakeRawRecordFormatter() + } else { + SnowflakeSchemaRecordFormatter() + } + } + @Singleton fun aggregatePublishingConfig(dataChannelMedium: DataChannelMedium): AggregatePublishingConfig { // NOT speed mode diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/check/SnowflakeChecker.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/check/SnowflakeChecker.kt index 064570c3dd9..8fc644a3941 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/check/SnowflakeChecker.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/check/SnowflakeChecker.kt @@ -13,13 +13,17 @@ import io.airbyte.cdk.load.data.FieldType import io.airbyte.cdk.load.data.ObjectType import io.airbyte.cdk.load.data.StringType import io.airbyte.cdk.load.message.Meta +import io.airbyte.cdk.load.schema.model.ColumnSchema +import io.airbyte.cdk.load.schema.model.StreamTableSchema +import io.airbyte.cdk.load.schema.model.TableName +import io.airbyte.cdk.load.schema.model.TableNames import io.airbyte.cdk.load.table.ColumnNameMapping -import io.airbyte.cdk.load.table.TableName import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient -import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName +import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager +import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration -import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer +import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter import jakarta.inject.Singleton import java.time.OffsetDateTime import java.util.UUID @@ -31,7 +35,7 @@ internal const val CHECK_COLUMN_NAME = "test_key" class SnowflakeChecker( private val snowflakeAirbyteClient: SnowflakeAirbyteClient, private val snowflakeConfiguration: SnowflakeConfiguration, - private val snowflakeColumnUtils: SnowflakeColumnUtils, + private val columnManager: SnowflakeColumnManager, ) : DestinationCheckerV2 { override fun check() { @@ -46,11 +50,40 @@ class SnowflakeChecker( Meta.AirbyteMetaFields.GENERATION_ID.fieldName to AirbyteValue.from(0), CHECK_COLUMN_NAME.toSnowflakeCompatibleName() to AirbyteValue.from("test-value") ) - val outputSchema = snowflakeConfiguration.schema.toSnowflakeCompatibleName() + val outputSchema = + if (snowflakeConfiguration.legacyRawTablesOnly) { + snowflakeConfiguration.schema + } else { + snowflakeConfiguration.schema.toSnowflakeCompatibleName() + } val tableName = "_airbyte_connection_test_${ UUID.randomUUID().toString().replace("-".toRegex(), "")}".toSnowflakeCompatibleName() val qualifiedTableName = TableName(namespace = outputSchema, name = tableName) + val tableSchema = + 
StreamTableSchema( + tableNames = + TableNames( + finalTableName = qualifiedTableName, + tempTableName = qualifiedTableName + ), + columnSchema = + ColumnSchema( + inputToFinalColumnNames = + mapOf( + CHECK_COLUMN_NAME to CHECK_COLUMN_NAME.toSnowflakeCompatibleName() + ), + finalSchema = + mapOf( + CHECK_COLUMN_NAME.toSnowflakeCompatibleName() to + io.airbyte.cdk.load.component.ColumnType("VARCHAR", false) + ), + inputSchema = + mapOf(CHECK_COLUMN_NAME to FieldType(StringType, nullable = false)) + ), + importType = Append + ) + val destinationStream = DestinationStream( unmappedNamespace = outputSchema, @@ -63,7 +96,8 @@ class SnowflakeChecker( generationId = 0L, minimumGenerationId = 0L, syncId = 0L, - namespaceMapper = NamespaceMapper() + namespaceMapper = NamespaceMapper(), + tableSchema = tableSchema ) runBlocking { try { @@ -75,14 +109,14 @@ class SnowflakeChecker( replace = true, ) - val columns = snowflakeAirbyteClient.describeTable(qualifiedTableName) val snowflakeInsertBuffer = SnowflakeInsertBuffer( tableName = qualifiedTableName, - columns = columns, snowflakeClient = snowflakeAirbyteClient, snowflakeConfiguration = snowflakeConfiguration, - snowflakeColumnUtils = snowflakeColumnUtils, + columnSchema = tableSchema.columnSchema, + columnManager = columnManager, + snowflakeRecordFormatter = SnowflakeSchemaRecordFormatter(), ) snowflakeInsertBuffer.accumulate(data) diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/client/SnowflakeAirbyteClient.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/client/SnowflakeAirbyteClient.kt index 7973d6bee4c..0cfafbb2bef 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/client/SnowflakeAirbyteClient.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/client/SnowflakeAirbyteClient.kt @@ -13,18 +13,16 @@ import io.airbyte.cdk.load.component.TableColumns import io.airbyte.cdk.load.component.TableOperationsClient import io.airbyte.cdk.load.component.TableSchema import io.airbyte.cdk.load.component.TableSchemaEvolutionClient +import io.airbyte.cdk.load.schema.model.TableName import io.airbyte.cdk.load.table.ColumnNameMapping -import io.airbyte.cdk.load.table.TableName import io.airbyte.cdk.load.util.deserializeToNode -import io.airbyte.integrations.destination.snowflake.db.ColumnDefinition -import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier -import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName +import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager +import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.sql.COUNT_TOTAL_ALIAS -import io.airbyte.integrations.destination.snowflake.sql.NOT_NULL -import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator import io.airbyte.integrations.destination.snowflake.sql.andLog +import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier import io.github.oshai.kotlinlogging.KotlinLogging import jakarta.inject.Singleton import java.sql.ResultSet @@ -41,13 
+39,10 @@ private val log = KotlinLogging.logger {} class SnowflakeAirbyteClient( private val dataSource: DataSource, private val sqlGenerator: SnowflakeDirectLoadSqlGenerator, - private val snowflakeColumnUtils: SnowflakeColumnUtils, private val snowflakeConfiguration: SnowflakeConfiguration, + private val columnManager: SnowflakeColumnManager, ) : TableOperationsClient, TableSchemaEvolutionClient { - private val airbyteColumnNames = - snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet() - override suspend fun countTable(tableName: TableName): Long? = try { dataSource.connection.use { connection -> @@ -126,7 +121,7 @@ class SnowflakeAirbyteClient( columnNameMapping: ColumnNameMapping, replace: Boolean ) { - execute(sqlGenerator.createTable(stream, tableName, columnNameMapping, replace)) + execute(sqlGenerator.createTable(tableName, stream.tableSchema, replace)) execute(sqlGenerator.createSnowflakeStage(tableName)) } @@ -163,7 +158,15 @@ class SnowflakeAirbyteClient( sourceTableName: TableName, targetTableName: TableName ) { - execute(sqlGenerator.copyTable(columnNameMapping, sourceTableName, targetTableName)) + // Get all column names from the mapping (both meta columns and user columns) + val columnNames = buildSet { + // Add Airbyte meta columns (using uppercase constants) + addAll(columnManager.getMetaColumnNames()) + // Add user columns from mapping + addAll(columnNameMapping.values) + } + + execute(sqlGenerator.copyTable(columnNames, sourceTableName, targetTableName)) } override suspend fun upsertTable( @@ -172,9 +175,7 @@ class SnowflakeAirbyteClient( sourceTableName: TableName, targetTableName: TableName ) { - execute( - sqlGenerator.upsertTable(stream, columnNameMapping, sourceTableName, targetTableName) - ) + execute(sqlGenerator.upsertTable(stream.tableSchema, sourceTableName, targetTableName)) } override suspend fun dropTable(tableName: TableName) { @@ -206,7 +207,7 @@ class SnowflakeAirbyteClient( stream: DestinationStream, columnNameMapping: ColumnNameMapping ): TableSchema { - return TableSchema(getColumnsFromStream(stream, columnNameMapping)) + return TableSchema(stream.tableSchema.columnSchema.finalSchema) } override suspend fun applyChangeset( @@ -253,7 +254,7 @@ class SnowflakeAirbyteClient( val columnName = escapeJsonIdentifier(rs.getString("name")) // Filter out airbyte columns - if (airbyteColumnNames.contains(columnName)) { + if (columnManager.getMetaColumnNames().contains(columnName)) { continue } val dataType = rs.getString("type").takeWhile { char -> char != '(' } @@ -271,49 +272,6 @@ class SnowflakeAirbyteClient( } } - internal fun getColumnsFromStream( - stream: DestinationStream, - columnNameMapping: ColumnNameMapping - ): Map<String, ColumnType> = - snowflakeColumnUtils - .columnsAndTypes(stream.schema.asColumns(), columnNameMapping) - .filter { column -> column.columnName !in airbyteColumnNames } - .associate { column -> - // columnsAndTypes returns types as either `FOO` or `FOO NOT NULL`, - // so check for that suffix.
- val nullable = !column.columnType.endsWith(NOT_NULL) - val type = - column.columnType - .takeWhile { char -> - // This is to remove any precision parts of the dialect type - char != '(' - } - .removeSuffix(NOT_NULL) - .trim() - - column.columnName to ColumnType(type, nullable) - } - - internal fun generateSchemaChanges( - columnsInDb: Set<ColumnDefinition>, - columnsInStream: Set<ColumnDefinition> - ): Triple<Set<ColumnDefinition>, Set<ColumnDefinition>, Set<ColumnDefinition>> { - val addedColumns = - columnsInStream.filter { it.name !in columnsInDb.map { col -> col.name } }.toSet() - val deletedColumns = - columnsInDb.filter { it.name !in columnsInStream.map { col -> col.name } }.toSet() - val commonColumns = - columnsInStream.filter { it.name in columnsInDb.map { col -> col.name } }.toSet() - val modifiedColumns = - commonColumns - .filter { - val dbType = columnsInDb.find { column -> it.name == column.name }?.type - it.type != dbType - } - .toSet() - return Triple(addedColumns, deletedColumns, modifiedColumns) - } - override suspend fun getGenerationId(tableName: TableName): Long = try { dataSource.connection.use { connection -> @@ -326,7 +284,7 @@ class SnowflakeAirbyteClient( * format. In order to make sure these strings will match any column names * that we have formatted in-memory, re-apply the escaping. */ - resultSet.getLong(snowflakeColumnUtils.getGenerationIdColumnName()) + resultSet.getLong(columnManager.getGenerationIdColumnName()) } else { log.warn { "No generation ID found for table ${tableName.toPrettyString()}, returning 0" } @@ -351,8 +309,8 @@ class SnowflakeAirbyteClient( execute(sqlGenerator.putInStage(tableName, tempFilePath)) } - fun copyFromStage(tableName: TableName, filename: String) { - execute(sqlGenerator.copyFromStage(tableName, filename)) + fun copyFromStage(tableName: TableName, filename: String, columnNames: List<String>) { + execute(sqlGenerator.copyFromStage(tableName, filename, columnNames)) } fun describeTable(tableName: TableName): LinkedHashMap<String, String> = diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/dataflow/SnowflakeAggregateFactory.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/dataflow/SnowflakeAggregateFactory.kt index 340c80f319f..70807de02d8 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/dataflow/SnowflakeAggregateFactory.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/dataflow/SnowflakeAggregateFactory.kt @@ -4,47 +4,41 @@ package io.airbyte.integrations.destination.snowflake.dataflow +import io.airbyte.cdk.load.command.DestinationCatalog import io.airbyte.cdk.load.dataflow.aggregate.Aggregate import io.airbyte.cdk.load.dataflow.aggregate.AggregateFactory import io.airbyte.cdk.load.dataflow.aggregate.StoreKey -import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig -import io.airbyte.cdk.load.table.TableName +import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig import io.airbyte.cdk.load.write.StreamStateStore import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient +import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration -import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils import 
io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer -import io.micronaut.cache.annotation.CacheConfig -import io.micronaut.cache.annotation.Cacheable +import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter import jakarta.inject.Singleton @Singleton -@CacheConfig("table-columns") -// class has to be open to make the cache stuff work -open class SnowflakeAggregateFactory( +class SnowflakeAggregateFactory( private val snowflakeClient: SnowflakeAirbyteClient, private val streamStateStore: StreamStateStore<DirectLoadTableExecutionConfig>, private val snowflakeConfiguration: SnowflakeConfiguration, - private val snowflakeColumnUtils: SnowflakeColumnUtils, + private val catalog: DestinationCatalog, + private val columnManager: SnowflakeColumnManager, + private val snowflakeRecordFormatter: SnowflakeRecordFormatter, ) : AggregateFactory { override fun create(key: StoreKey): Aggregate { + val stream = catalog.getStream(key) val tableName = streamStateStore.get(key)!!.tableName + val buffer = SnowflakeInsertBuffer( tableName = tableName, - columns = getTableColumns(tableName), snowflakeClient = snowflakeClient, snowflakeConfiguration = snowflakeConfiguration, - snowflakeColumnUtils = snowflakeColumnUtils, + columnSchema = stream.tableSchema.columnSchema, + columnManager = columnManager, + snowflakeRecordFormatter = snowflakeRecordFormatter, ) return SnowflakeAggregate(buffer = buffer) } - - // We assume that a table isn't getting altered _during_ a sync. - // This allows us to only SHOW COLUMNS once per table per sync, - // rather than refetching it on every aggregate. - @Cacheable - // function has to be open to make caching work - internal open fun getTableColumns(tableName: TableName) = - snowflakeClient.describeTable(tableName) } diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/db/ColumnDefinition.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/db/ColumnDefinition.kt deleted file mode 100644 index 567722299c7..00000000000 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/db/ColumnDefinition.kt +++ /dev/null @@ -1,13 +0,0 @@ -/* - * Copyright (c) 2025 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.integrations.destination.snowflake.db - -/** - * Jdbc destination column definition representation - * - * @param name - * @param type - */ -data class ColumnDefinition(val name: String, val type: String) diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/db/SnowflakeDirectLoadDatabaseInitialStatusGatherer.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/db/SnowflakeDirectLoadDatabaseInitialStatusGatherer.kt index 75d39b223a5..1520585c166 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/db/SnowflakeDirectLoadDatabaseInitialStatusGatherer.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/db/SnowflakeDirectLoadDatabaseInitialStatusGatherer.kt @@ -4,17 +4,17 @@ package io.airbyte.integrations.destination.snowflake.db +import io.airbyte.cdk.load.command.DestinationCatalog import io.airbyte.cdk.load.component.TableOperationsClient -import io.airbyte.cdk.load.orchestration.db.BaseDirectLoadInitialStatusGatherer -import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator +import io.airbyte.cdk.load.table.BaseDirectLoadInitialStatusGatherer import jakarta.inject.Singleton @Singleton class SnowflakeDirectLoadDatabaseInitialStatusGatherer( tableOperationsClient: TableOperationsClient, - tempTableNameGenerator: TempTableNameGenerator, + catalog: DestinationCatalog, ) : BaseDirectLoadInitialStatusGatherer( tableOperationsClient, - tempTableNameGenerator, + catalog, ) diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/db/SnowflakeNameGenerators.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/db/SnowflakeNameGenerators.kt deleted file mode 100644 index 5c410fec706..00000000000 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/db/SnowflakeNameGenerators.kt +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Copyright (c) 2025 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.integrations.destination.snowflake.db - -import io.airbyte.cdk.ConfigErrorException -import io.airbyte.cdk.load.command.DestinationStream -import io.airbyte.cdk.load.orchestration.db.ColumnNameGenerator -import io.airbyte.cdk.load.orchestration.db.FinalTableNameGenerator -import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TypingDedupingUtil -import io.airbyte.cdk.load.table.TableName -import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration -import io.airbyte.integrations.destination.snowflake.sql.QUOTE -import jakarta.inject.Singleton - -@Singleton -class SnowflakeFinalTableNameGenerator(private val config: SnowflakeConfiguration) : - FinalTableNameGenerator { - override fun getTableName(streamDescriptor: DestinationStream.Descriptor): TableName { - val namespace = streamDescriptor.namespace ?: config.schema - return if (!config.legacyRawTablesOnly) { - TableName( - namespace = namespace.toSnowflakeCompatibleName(), - name = streamDescriptor.name.toSnowflakeCompatibleName(), - ) - } else { - TableName( - namespace = config.internalTableSchema, - name = - TypingDedupingUtil.concatenateRawTableName( - namespace = escapeJsonIdentifier(namespace), - name = escapeJsonIdentifier(streamDescriptor.name), - ), - ) - } - } -} - -@Singleton -class SnowflakeColumnNameGenerator(private val config: SnowflakeConfiguration) : - ColumnNameGenerator { - override fun getColumnName(column: String): ColumnNameGenerator.ColumnName { - return if (!config.legacyRawTablesOnly) { - ColumnNameGenerator.ColumnName( - column.toSnowflakeCompatibleName(), - column.toSnowflakeCompatibleName(), - ) - } else { - ColumnNameGenerator.ColumnName( - column, - column, - ) - } - } -} - -/** - * Escapes double-quotes in a JSON identifier by doubling them. This shit is legacy -- I don't know - * why this would be necessary but no harm in keeping it so I am keeping it. - * - * @return The escaped identifier. - */ -fun escapeJsonIdentifier(identifier: String): String { - // Note that we don't need to escape backslashes here! - // The only special character in an identifier is the double-quote, which needs to be - // doubled. - return identifier.replace(QUOTE, "$QUOTE$QUOTE") -} - -/** - * Transforms a string to be compatible with Snowflake table and column names. - * - * @return The transformed string suitable for Snowflake identifiers. - */ -fun String.toSnowflakeCompatibleName(): String { - var identifier = this - - // Handle empty strings - if (identifier.isEmpty()) { - throw ConfigErrorException("Empty string is invalid identifier") - } - - // Snowflake scripting language does something weird when the `${` bigram shows up in the - // script so replace these with something else. - // For completeness, if we trigger this, also replace closing curly braces with underscores. 
- if (identifier.contains("\${")) { - identifier = identifier.replace("$", "_").replace("{", "_").replace("}", "_") - } - - // Escape double quotes - identifier = escapeJsonIdentifier(identifier) - - return identifier.uppercase() -} diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/schema/SnowflakeColumnManager.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/schema/SnowflakeColumnManager.kt new file mode 100644 index 00000000000..40f855e60fb --- /dev/null +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/schema/SnowflakeColumnManager.kt @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2025 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.integrations.destination.snowflake.schema + +import io.airbyte.cdk.load.component.ColumnType +import io.airbyte.cdk.load.message.Meta +import io.airbyte.cdk.load.schema.model.ColumnSchema +import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration +import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_EXTRACTED_AT +import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_GENERATION_ID +import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_META +import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_RAW_ID +import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDataType +import jakarta.inject.Singleton + +/** + * Manages column names and ordering for Snowflake tables based on whether legacy raw tables mode is + * enabled. + * + * TODO: We should add meta column munging and raw table support to the CDK, so this extra layer of + * management shouldn't be necessary. + */ +@Singleton +class SnowflakeColumnManager( + private val config: SnowflakeConfiguration, +) { + /** + * Get the list of column names for a table in the order they should appear in the CSV file and + * COPY INTO statement. + * + * Warning: MUST match the order defined in SnowflakeRecordFormatter. + * + * @param columnSchema The schema containing column information (ignored in raw mode) + * @return List of column names in the correct order + */ + fun getTableColumnNames(columnSchema: ColumnSchema): List<String> { + return buildList { + addAll(getMetaColumnNames()) + addAll(columnSchema.finalSchema.keys) + } + } + + /** + * Get the list of Airbyte meta column names. In schema mode, these are uppercase. In raw mode, + * they are lowercase and include loaded_at. + * + * @return Set of meta column names + */ + fun getMetaColumnNames(): Set<String> = + if (config.legacyRawTablesOnly) { + Constants.rawModeMetaColNames + } else { + Constants.schemaModeMetaColNames + } + + /** + * Get the Airbyte meta columns as a map of column name to ColumnType. This provides both the + * column names and their types for table creation.
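+ * For example, in schema mode the keys are the uppercase names _AIRBYTE_RAW_ID, _AIRBYTE_EXTRACTED_AT, _AIRBYTE_META and _AIRBYTE_GENERATION_ID, matching Constants.schemaModeMetaColumns below.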
* + * In raw mode the returned map also includes the _airbyte_loaded_at column; in schema mode it + * does not. + * + * @return Map of meta column names to their types + */ + fun getMetaColumns(): LinkedHashMap<String, ColumnType> { + return if (config.legacyRawTablesOnly) { + Constants.rawModeMetaColumns + } else { + Constants.schemaModeMetaColumns + } + } + + fun getGenerationIdColumnName(): String { + return if (config.legacyRawTablesOnly) { + Meta.COLUMN_NAME_AB_GENERATION_ID + } else { + SNOWFLAKE_AB_GENERATION_ID + } + } + + object Constants { + val rawModeMetaColumns = + linkedMapOf( + Meta.COLUMN_NAME_AB_RAW_ID to ColumnType(SnowflakeDataType.VARCHAR.typeName, false), + Meta.COLUMN_NAME_AB_EXTRACTED_AT to + ColumnType( + SnowflakeDataType.TIMESTAMP_TZ.typeName, + false, + ), + Meta.COLUMN_NAME_AB_META to ColumnType(SnowflakeDataType.VARIANT.typeName, false), + Meta.COLUMN_NAME_AB_GENERATION_ID to + ColumnType( + SnowflakeDataType.NUMBER.typeName, + true, + ), + Meta.COLUMN_NAME_AB_LOADED_AT to + ColumnType( + SnowflakeDataType.TIMESTAMP_TZ.typeName, + true, + ), + ) + + val schemaModeMetaColumns = + linkedMapOf( + SNOWFLAKE_AB_RAW_ID to ColumnType(SnowflakeDataType.VARCHAR.typeName, false), + SNOWFLAKE_AB_EXTRACTED_AT to + ColumnType(SnowflakeDataType.TIMESTAMP_TZ.typeName, false), + SNOWFLAKE_AB_META to ColumnType(SnowflakeDataType.VARIANT.typeName, false), + SNOWFLAKE_AB_GENERATION_ID to ColumnType(SnowflakeDataType.NUMBER.typeName, true), + ) + + val rawModeMetaColNames: Set<String> = rawModeMetaColumns.keys + + val schemaModeMetaColNames: Set<String> = schemaModeMetaColumns.keys + } +} diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/schema/SnowflakeNamingUtils.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/schema/SnowflakeNamingUtils.kt new file mode 100644 index 00000000000..da006d7fdcd --- /dev/null +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/schema/SnowflakeNamingUtils.kt @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2025 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.integrations.destination.snowflake.schema + +import io.airbyte.cdk.ConfigErrorException +import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier + +/** + * Transforms a string to be compatible with Snowflake table and column names. + * + * @return The transformed string suitable for Snowflake identifiers. + */ +fun String.toSnowflakeCompatibleName(): String { + var identifier = this + + // Handle empty strings + if (identifier.isEmpty()) { + throw ConfigErrorException("Empty string is invalid identifier") + } + + // Snowflake scripting language does something weird when the `${` bigram shows up in the + // script so replace these with something else. + // For completeness, if we trigger this, also replace closing curly braces with underscores.
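+ // For example, a hypothetical identifier `my${col}` comes out as `MY__COL_` after the replacements and uppercasing below.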
+ if (identifier.contains("\${")) { + identifier = identifier.replace("$", "_").replace("{", "_").replace("}", "_") + } + + // Escape double quotes + identifier = escapeJsonIdentifier(identifier) + + return identifier.uppercase() +} diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/schema/SnowflakeTableSchemaMapper.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/schema/SnowflakeTableSchemaMapper.kt new file mode 100644 index 00000000000..e8b9c1a5acc --- /dev/null +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/schema/SnowflakeTableSchemaMapper.kt @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2025 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.integrations.destination.snowflake.schema + +import io.airbyte.cdk.load.command.DestinationStream +import io.airbyte.cdk.load.component.ColumnType +import io.airbyte.cdk.load.data.ArrayType +import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema +import io.airbyte.cdk.load.data.BooleanType +import io.airbyte.cdk.load.data.DateType +import io.airbyte.cdk.load.data.FieldType +import io.airbyte.cdk.load.data.IntegerType +import io.airbyte.cdk.load.data.NumberType +import io.airbyte.cdk.load.data.ObjectType +import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema +import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema +import io.airbyte.cdk.load.data.StringType +import io.airbyte.cdk.load.data.TimeTypeWithTimezone +import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone +import io.airbyte.cdk.load.data.TimestampTypeWithTimezone +import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone +import io.airbyte.cdk.load.data.UnionType +import io.airbyte.cdk.load.data.UnknownType +import io.airbyte.cdk.load.message.Meta +import io.airbyte.cdk.load.schema.TableSchemaMapper +import io.airbyte.cdk.load.schema.model.StreamTableSchema +import io.airbyte.cdk.load.schema.model.TableName +import io.airbyte.cdk.load.table.TempTableNameGenerator +import io.airbyte.cdk.load.table.TypingDedupingUtil +import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration +import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDataType +import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier +import jakarta.inject.Singleton + +@Singleton +class SnowflakeTableSchemaMapper( + private val config: SnowflakeConfiguration, + private val tempTableNameGenerator: TempTableNameGenerator, +) : TableSchemaMapper { + override fun toFinalTableName(desc: DestinationStream.Descriptor): TableName { + val namespace = desc.namespace ?: config.schema + return if (!config.legacyRawTablesOnly) { + TableName( + namespace = namespace.toSnowflakeCompatibleName(), + name = desc.name.toSnowflakeCompatibleName(), + ) + } else { + TableName( + namespace = config.internalTableSchema, + name = + TypingDedupingUtil.concatenateRawTableName( + namespace = escapeJsonIdentifier(namespace), + name = escapeJsonIdentifier(desc.name), + ), + ) + } + } + + override fun toTempTableName(tableName: TableName): TableName { + return tempTableNameGenerator.generate(tableName) + } + + override fun toColumnName(name: String): String { + return if (!config.legacyRawTablesOnly) { + name.toSnowflakeCompatibleName() + } else { + // In legacy mode, column names are not transformed + name + } + } + + override fun toColumnType(fieldType: FieldType): 
ColumnType { + val snowflakeType = + when (fieldType.type) { + // Simple types + BooleanType -> SnowflakeDataType.BOOLEAN.typeName + IntegerType -> SnowflakeDataType.NUMBER.typeName + NumberType -> SnowflakeDataType.FLOAT.typeName + StringType -> SnowflakeDataType.VARCHAR.typeName + + // Temporal types + DateType -> SnowflakeDataType.DATE.typeName + TimeTypeWithTimezone -> SnowflakeDataType.VARCHAR.typeName + TimeTypeWithoutTimezone -> SnowflakeDataType.TIME.typeName + TimestampTypeWithTimezone -> SnowflakeDataType.TIMESTAMP_TZ.typeName + TimestampTypeWithoutTimezone -> SnowflakeDataType.TIMESTAMP_NTZ.typeName + + // Semistructured types + is ArrayType, + ArrayTypeWithoutSchema -> SnowflakeDataType.ARRAY.typeName + is ObjectType, + ObjectTypeWithEmptySchema, + ObjectTypeWithoutSchema -> SnowflakeDataType.OBJECT.typeName + is UnionType -> SnowflakeDataType.VARIANT.typeName + is UnknownType -> SnowflakeDataType.VARIANT.typeName + } + + return ColumnType(snowflakeType, fieldType.nullable) + } + + override fun toFinalSchema(tableSchema: StreamTableSchema): StreamTableSchema { + if (!config.legacyRawTablesOnly) { + return tableSchema + } + + return StreamTableSchema( + tableNames = tableSchema.tableNames, + columnSchema = + tableSchema.columnSchema.copy( + finalSchema = + mapOf( + Meta.COLUMN_NAME_DATA to + ColumnType(SnowflakeDataType.OBJECT.typeName, false) + ) + ), + importType = tableSchema.importType, + ) + } +} diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeColumnUtils.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeColumnUtils.kt deleted file mode 100644 index a97b046c9af..00000000000 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeColumnUtils.kt +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Copyright (c) 2025 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.integrations.destination.snowflake.sql - -import com.google.common.annotations.VisibleForTesting -import io.airbyte.cdk.load.data.AirbyteType -import io.airbyte.cdk.load.data.ArrayType -import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema -import io.airbyte.cdk.load.data.BooleanType -import io.airbyte.cdk.load.data.DateType -import io.airbyte.cdk.load.data.FieldType -import io.airbyte.cdk.load.data.IntegerType -import io.airbyte.cdk.load.data.NumberType -import io.airbyte.cdk.load.data.ObjectType -import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema -import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema -import io.airbyte.cdk.load.data.StringType -import io.airbyte.cdk.load.data.TimeTypeWithTimezone -import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone -import io.airbyte.cdk.load.data.TimestampTypeWithTimezone -import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone -import io.airbyte.cdk.load.data.UnionType -import io.airbyte.cdk.load.data.UnknownType -import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT -import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID -import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT -import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META -import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID -import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA -import io.airbyte.cdk.load.table.ColumnNameMapping -import io.airbyte.integrations.destination.snowflake.db.SnowflakeColumnNameGenerator -import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName -import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration -import jakarta.inject.Singleton -import kotlin.collections.component1 -import kotlin.collections.component2 -import kotlin.collections.joinToString -import kotlin.collections.map -import kotlin.collections.plus - -internal const val NOT_NULL = "NOT NULL" - -internal val DEFAULT_COLUMNS = - listOf( - ColumnAndType( - columnName = COLUMN_NAME_AB_RAW_ID, - columnType = "${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL" - ), - ColumnAndType( - columnName = COLUMN_NAME_AB_EXTRACTED_AT, - columnType = "${SnowflakeDataType.TIMESTAMP_TZ.typeName} $NOT_NULL" - ), - ColumnAndType( - columnName = COLUMN_NAME_AB_META, - columnType = "${SnowflakeDataType.VARIANT.typeName} $NOT_NULL" - ), - ColumnAndType( - columnName = COLUMN_NAME_AB_GENERATION_ID, - columnType = SnowflakeDataType.NUMBER.typeName - ), - ) - -internal val RAW_DATA_COLUMN = - ColumnAndType( - columnName = COLUMN_NAME_DATA, - columnType = "${SnowflakeDataType.VARIANT.typeName} $NOT_NULL" - ) - -internal val RAW_COLUMNS = - listOf( - ColumnAndType( - columnName = COLUMN_NAME_AB_LOADED_AT, - columnType = SnowflakeDataType.TIMESTAMP_TZ.typeName - ), - RAW_DATA_COLUMN - ) - -@Singleton -class SnowflakeColumnUtils( - private val snowflakeConfiguration: SnowflakeConfiguration, - private val snowflakeColumnNameGenerator: SnowflakeColumnNameGenerator, -) { - - @VisibleForTesting - internal fun defaultColumns(): List<ColumnAndType> = - if (snowflakeConfiguration.legacyRawTablesOnly) { - DEFAULT_COLUMNS + RAW_COLUMNS - } else { - DEFAULT_COLUMNS - } - - internal fun formattedDefaultColumns(): List<ColumnAndType> = - defaultColumns().map { - ColumnAndType( - columnName = formatColumnName(it.columnName, false), - columnType = it.columnType, - ) - } - - fun getGenerationIdColumnName(): String { - return if (snowflakeConfiguration.legacyRawTablesOnly) { -
COLUMN_NAME_AB_GENERATION_ID - } else { - COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName() - } - } - - fun getColumnNames(columnNameMapping: ColumnNameMapping): String = - if (snowflakeConfiguration.legacyRawTablesOnly) { - getFormattedDefaultColumnNames(true).joinToString(",") - } else { - (getFormattedDefaultColumnNames(true) + - columnNameMapping.map { (_, actualName) -> actualName.quote() }) - .joinToString(",") - } - - fun getFormattedDefaultColumnNames(quote: Boolean = false): List<String> = - defaultColumns().map { formatColumnName(it.columnName, quote) } - - fun getFormattedColumnNames( - columns: Map<String, FieldType>, - columnNameMapping: ColumnNameMapping, - quote: Boolean = true, - ): List<String> = - if (snowflakeConfiguration.legacyRawTablesOnly) { - getFormattedDefaultColumnNames(quote) - } else { - getFormattedDefaultColumnNames(quote) + - columns.map { (fieldName, _) -> - val columnName = columnNameMapping[fieldName] ?: fieldName - if (quote) columnName.quote() else columnName - } - } - - fun columnsAndTypes( - columns: Map<String, FieldType>, - columnNameMapping: ColumnNameMapping - ): List<ColumnAndType> = - if (snowflakeConfiguration.legacyRawTablesOnly) { - formattedDefaultColumns() - } else { - formattedDefaultColumns() + - columns.map { (fieldName, type) -> - val columnName = columnNameMapping[fieldName] ?: fieldName - val typeName = toDialectType(type.type) - ColumnAndType( - columnName = columnName, - columnType = if (type.nullable) typeName else "$typeName $NOT_NULL", - ) - } - } - - fun formatColumnName( - columnName: String, - quote: Boolean = true, - ): String { - val formattedColumnName = - if (columnName == COLUMN_NAME_DATA) columnName - else snowflakeColumnNameGenerator.getColumnName(columnName).displayName - return if (quote) formattedColumnName.quote() else formattedColumnName - } - - fun toDialectType(type: AirbyteType): String = - when (type) { - // Simple types - BooleanType -> SnowflakeDataType.BOOLEAN.typeName - IntegerType -> SnowflakeDataType.NUMBER.typeName - NumberType -> SnowflakeDataType.FLOAT.typeName - StringType -> SnowflakeDataType.VARCHAR.typeName - - // Temporal types - DateType -> SnowflakeDataType.DATE.typeName - TimeTypeWithTimezone -> SnowflakeDataType.VARCHAR.typeName - TimeTypeWithoutTimezone -> SnowflakeDataType.TIME.typeName - TimestampTypeWithTimezone -> SnowflakeDataType.TIMESTAMP_TZ.typeName - TimestampTypeWithoutTimezone -> SnowflakeDataType.TIMESTAMP_NTZ.typeName - - // Semistructured types - is ArrayType, - ArrayTypeWithoutSchema -> SnowflakeDataType.ARRAY.typeName - is ObjectType, - ObjectTypeWithEmptySchema, - ObjectTypeWithoutSchema -> SnowflakeDataType.OBJECT.typeName - is UnionType -> SnowflakeDataType.VARIANT.typeName - is UnknownType -> SnowflakeDataType.VARIANT.typeName - } -} - -data class ColumnAndType(val columnName: String, val columnType: String) { - override fun toString(): String { - return "${columnName.quote()} $columnType" - } -} - -/** - * Surrounds the string instance with double quotation marks (e.g. "some string" -> "\"some - * string\"). 
- */ -fun String.quote() = "$QUOTE$this$QUOTE" diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeDataType.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeDataType.kt index 32b5f7c2c99..d5c4f784e13 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeDataType.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeDataType.kt @@ -10,7 +10,7 @@ package io.airbyte.integrations.destination.snowflake.sql */ enum class SnowflakeDataType(val typeName: String) { // Numeric types - NUMBER("NUMBER(38,0)"), + NUMBER("NUMBER"), FLOAT("FLOAT"), // String & binary types diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeDirectLoadSqlGenerator.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeDirectLoadSqlGenerator.kt index 1e2511cceef..533d81abe1c 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeDirectLoadSqlGenerator.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeDirectLoadSqlGenerator.kt @@ -4,16 +4,19 @@ package io.airbyte.integrations.destination.snowflake.sql -import io.airbyte.cdk.load.command.Dedupe -import io.airbyte.cdk.load.command.DestinationStream +import com.google.common.annotations.VisibleForTesting import io.airbyte.cdk.load.component.ColumnType import io.airbyte.cdk.load.component.ColumnTypeChange import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT +import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID +import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META +import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID +import io.airbyte.cdk.load.schema.model.StreamTableSchema +import io.airbyte.cdk.load.schema.model.TableName import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN -import io.airbyte.cdk.load.table.ColumnNameMapping -import io.airbyte.cdk.load.table.TableName import io.airbyte.cdk.load.util.UUIDGenerator -import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName +import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager +import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.spec.CdcDeletionMode import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.write.load.CSV_FIELD_SEPARATOR @@ -22,6 +25,14 @@ import io.github.oshai.kotlinlogging.KotlinLogging import jakarta.inject.Singleton internal const val COUNT_TOTAL_ALIAS = "TOTAL" +internal const val NOT_NULL = "NOT NULL" + +// Snowflake-compatible (uppercase) versions of the Airbyte meta column names +internal val SNOWFLAKE_AB_RAW_ID = COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName() +internal val SNOWFLAKE_AB_EXTRACTED_AT = COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName() +internal val SNOWFLAKE_AB_META = 
COLUMN_NAME_AB_META.toSnowflakeCompatibleName() +internal val SNOWFLAKE_AB_GENERATION_ID = COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName() +internal val SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN = CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName() private val log = KotlinLogging.logger {} @@ -36,80 +47,91 @@ fun String.andLog(): String { @Singleton class SnowflakeDirectLoadSqlGenerator( - private val columnUtils: SnowflakeColumnUtils, private val uuidGenerator: UUIDGenerator, - private val snowflakeConfiguration: SnowflakeConfiguration, - private val snowflakeSqlNameUtils: SnowflakeSqlNameUtils, + private val config: SnowflakeConfiguration, + private val columnManager: SnowflakeColumnManager, ) { fun countTable(tableName: TableName): String { - return "SELECT COUNT(*) AS $COUNT_TOTAL_ALIAS FROM ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog() + return "SELECT COUNT(*) AS $COUNT_TOTAL_ALIAS FROM ${fullyQualifiedName(tableName)}".andLog() } fun createNamespace(namespace: String): String { - return "CREATE SCHEMA IF NOT EXISTS ${snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)}".andLog() + return "CREATE SCHEMA IF NOT EXISTS ${fullyQualifiedNamespace(namespace)}".andLog() } fun createTable( - stream: DestinationStream, tableName: TableName, - columnNameMapping: ColumnNameMapping, + tableSchema: StreamTableSchema, replace: Boolean ): String { + val finalSchema = tableSchema.columnSchema.finalSchema + val metaColumns = columnManager.getMetaColumns() + + // Build column declarations from the meta columns and user schema val columnDeclarations = - columnUtils - .columnsAndTypes(stream.schema.asColumns(), columnNameMapping) - .joinToString(",\n") + buildList { + // Add Airbyte meta columns from the column manager + metaColumns.forEach { (columnName, columnType) -> + val nullability = if (columnType.nullable) "" else " NOT NULL" + add("${columnName.quote()} ${columnType.type}$nullability") + } + + // Add user columns from the munged schema + finalSchema.forEach { (columnName, columnType) -> + val nullability = if (columnType.nullable) "" else " NOT NULL" + add("${columnName.quote()} ${columnType.type}$nullability") + } + } + .joinToString(",\n ") // Snowflake supports CREATE OR REPLACE TABLE, which is simpler than drop+recreate val createOrReplace = if (replace) "CREATE OR REPLACE" else "CREATE" val createTableStatement = """ - $createOrReplace TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)} ( - $columnDeclarations - ) - """.trimIndent() + |$createOrReplace TABLE ${fullyQualifiedName(tableName)} ( + | $columnDeclarations + |) + """.trimMargin() // Something was tripping up trimIndent so we opt for trimMargin return createTableStatement.andLog() } fun showColumns(tableName: TableName): String = - "SHOW COLUMNS IN TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog() + "SHOW COLUMNS IN TABLE ${fullyQualifiedName(tableName)}".andLog() fun copyTable( - columnNameMapping: ColumnNameMapping, + columnNames: Set<String>, sourceTableName: TableName, targetTableName: TableName ): String { - val columnNames = columnUtils.getColumnNames(columnNameMapping) + val columnList = columnNames.joinToString(", ") { it.quote() } return """ - INSERT INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} + INSERT INTO ${fullyQualifiedName(targetTableName)} ( - $columnNames + $columnList ) SELECT - $columnNames - FROM ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} + $columnList + FROM ${fullyQualifiedName(sourceTableName)} """ .trimIndent() .andLog() } 
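+ // Sketch of the MERGE assembled below (identifiers illustrative): + // MERGE INTO <target> AS target_table USING (<deduped source select>) AS new_record + // ON <pk equality> [WHEN MATCHED AND <cdc deleted> AND <newer> THEN DELETE] + // WHEN MATCHED AND <newer> THEN UPDATE SET ... + // WHEN NOT MATCHED THEN INSERT (<columns>) VALUES (<new_record columns>)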
fun upsertTable( - stream: DestinationStream, - columnNameMapping: ColumnNameMapping, + tableSchema: StreamTableSchema, sourceTableName: TableName, targetTableName: TableName ): String { - val importType = stream.importType as Dedupe + val finalSchema = tableSchema.columnSchema.finalSchema // Build primary key matching condition + val pks = tableSchema.getPrimaryKey().flatten() val pkEquivalent = - if (importType.primaryKey.isNotEmpty()) { - importType.primaryKey.joinToString(" AND ") { fieldPath -> - val fieldName = fieldPath.first() - val columnName = columnNameMapping[fieldName] ?: fieldName + if (pks.isNotEmpty()) { + pks.joinToString(" AND ") { columnName -> val targetTableColumnName = "target_table.${columnName.quote()}" val newRecordColumnName = "new_record.${columnName.quote()}" """($targetTableColumnName = $newRecordColumnName OR ($targetTableColumnName IS NULL AND $newRecordColumnName IS NULL))""" @@ -120,80 +142,62 @@ class SnowflakeDirectLoadSqlGenerator( } // Build column lists for INSERT and UPDATE - val columnList: String = - columnUtils - .getFormattedColumnNames( - columns = stream.schema.asColumns(), - columnNameMapping = columnNameMapping, - quote = false, - ) - .joinToString( - ",\n", - ) { - it.quote() - } + val allColumns = buildList { + add(SNOWFLAKE_AB_RAW_ID) + add(SNOWFLAKE_AB_EXTRACTED_AT) + add(SNOWFLAKE_AB_META) + add(SNOWFLAKE_AB_GENERATION_ID) + addAll(finalSchema.keys) + } + val columnList: String = allColumns.joinToString(",\n ") { it.quote() } val newRecordColumnList: String = - columnUtils - .getFormattedColumnNames( - columns = stream.schema.asColumns(), - columnNameMapping = columnNameMapping, - quote = false, - ) - .joinToString(",\n") { "new_record.${it.quote()}" } + allColumns.joinToString(",\n ") { "new_record.${it.quote()}" } // Get deduped records from source - val selectSourceRecords = selectDedupedRecords(stream, sourceTableName, columnNameMapping) + val selectSourceRecords = selectDedupedRecords(tableSchema, sourceTableName) // Build cursor comparison for determining which record is newer val cursorComparison: String - if (importType.cursor.isNotEmpty()) { - val cursorFieldName = importType.cursor.first() - val cursor = (columnNameMapping[cursorFieldName] ?: cursorFieldName) + val cursor = tableSchema.getCursor().firstOrNull() + if (cursor != null) { val targetTableCursor = "target_table.${cursor.quote()}" val newRecordCursor = "new_record.${cursor.quote()}" cursorComparison = """ ( $targetTableCursor < $newRecordCursor - OR ($targetTableCursor = $newRecordCursor AND target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}") - OR ($targetTableCursor IS NULL AND $newRecordCursor IS NULL AND target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}") + OR ($targetTableCursor = $newRecordCursor AND target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT") + OR ($targetTableCursor IS NULL AND $newRecordCursor IS NULL AND target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT") OR ($targetTableCursor IS NULL AND $newRecordCursor IS $NOT_NULL) ) """.trimIndent() } else { // No cursor - use extraction timestamp only cursorComparison = - """target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}"""" + 
"""target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT"""" } // Build column assignments for UPDATE val columnAssignments: String = - columnUtils - .getFormattedColumnNames( - columns = stream.schema.asColumns(), - columnNameMapping = columnNameMapping, - quote = false, - ) - .joinToString(",\n") { column -> - "${column.quote()} = new_record.${column.quote()}" - } + allColumns.joinToString(",\n ") { column -> + "${column.quote()} = new_record.${column.quote()}" + } // Handle CDC deletions based on mode val cdcDeleteClause: String val cdcSkipInsertClause: String if ( - stream.schema.asColumns().containsKey(CDC_DELETED_AT_COLUMN) && - snowflakeConfiguration.cdcDeletionMode == CdcDeletionMode.HARD_DELETE + finalSchema.containsKey(SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN) && + config.cdcDeletionMode == CdcDeletionMode.HARD_DELETE ) { // Execute CDC deletions if there's already a record cdcDeleteClause = - "WHEN MATCHED AND new_record.\"${CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()}\" IS NOT NULL AND $cursorComparison THEN DELETE" + "WHEN MATCHED AND new_record.\"${SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN}\" IS NOT NULL AND $cursorComparison THEN DELETE" // And skip insertion entirely if there's no matching record. // (This is possible if a single T+D batch contains both an insertion and deletion for // the same PK) - cdcSkipInsertClause = - "AND new_record.\"${CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()}\" IS NULL" + cdcSkipInsertClause = "AND new_record.\"${SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN}\" IS NULL" } else { cdcDeleteClause = "" cdcSkipInsertClause = "" @@ -203,35 +207,35 @@ class SnowflakeDirectLoadSqlGenerator( val mergeStatement = if (cdcDeleteClause.isNotEmpty()) { """ - MERGE INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} AS target_table - USING ( - $selectSourceRecords - ) AS new_record - ON $pkEquivalent - $cdcDeleteClause - WHEN MATCHED AND $cursorComparison THEN UPDATE SET - $columnAssignments - WHEN NOT MATCHED $cdcSkipInsertClause THEN INSERT ( - $columnList - ) VALUES ( - $newRecordColumnList - ) - """.trimIndent() + |MERGE INTO ${fullyQualifiedName(targetTableName)} AS target_table + |USING ( + |$selectSourceRecords + |) AS new_record + |ON $pkEquivalent + |$cdcDeleteClause + |WHEN MATCHED AND $cursorComparison THEN UPDATE SET + | $columnAssignments + |WHEN NOT MATCHED $cdcSkipInsertClause THEN INSERT ( + | $columnList + |) VALUES ( + | $newRecordColumnList + |) + """.trimMargin() } else { """ - MERGE INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} AS target_table - USING ( - $selectSourceRecords - ) AS new_record - ON $pkEquivalent - WHEN MATCHED AND $cursorComparison THEN UPDATE SET - $columnAssignments - WHEN NOT MATCHED THEN INSERT ( - $columnList - ) VALUES ( - $newRecordColumnList - ) - """.trimIndent() + |MERGE INTO ${fullyQualifiedName(targetTableName)} AS target_table + |USING ( + |$selectSourceRecords + |) AS new_record + |ON $pkEquivalent + |WHEN MATCHED AND $cursorComparison THEN UPDATE SET + | $columnAssignments + |WHEN NOT MATCHED THEN INSERT ( + | $columnList + |) VALUES ( + | $newRecordColumnList + |) + """.trimMargin() } return mergeStatement.andLog() @@ -242,75 +246,66 @@ class SnowflakeDirectLoadSqlGenerator( * table. Uses ROW_NUMBER() window function to select the most recent record per primary key. 
*/ private fun selectDedupedRecords( - stream: DestinationStream, - sourceTableName: TableName, - columnNameMapping: ColumnNameMapping + tableSchema: StreamTableSchema, + sourceTableName: TableName ): String { - val columnList: String = - columnUtils - .getFormattedColumnNames( - columns = stream.schema.asColumns(), - columnNameMapping = columnNameMapping, - quote = false, - ) - .joinToString( - ",\n", - ) { - it.quote() - } - val importType = stream.importType as Dedupe + val allColumns = buildList { + add(SNOWFLAKE_AB_RAW_ID) + add(SNOWFLAKE_AB_EXTRACTED_AT) + add(SNOWFLAKE_AB_META) + add(SNOWFLAKE_AB_GENERATION_ID) + addAll(tableSchema.columnSchema.finalSchema.keys) + } + val columnList: String = allColumns.joinToString(",\n ") { it.quote() } // Build the primary key list for partitioning + val pks = tableSchema.getPrimaryKey().flatten() val pkList = - if (importType.primaryKey.isNotEmpty()) { - importType.primaryKey.joinToString(",") { fieldPath -> - (columnNameMapping[fieldPath.first()] ?: fieldPath.first()).quote() - } + if (pks.isNotEmpty()) { + pks.joinToString(",") { it.quote() } } else { // Should not happen as we check this earlier, but handle it defensively throw IllegalArgumentException("Cannot deduplicate without primary key") } // Build cursor order clause for sorting within each partition + val cursor = tableSchema.getCursor().firstOrNull() val cursorOrderClause = - if (importType.cursor.isNotEmpty()) { - val columnName = - (columnNameMapping[importType.cursor.first()] ?: importType.cursor.first()) - .quote() - "$columnName DESC NULLS LAST," + if (cursor != null) { + "${cursor.quote()} DESC NULLS LAST," } else { "" } return """ - WITH records AS ( - SELECT - $columnList - FROM ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} - ), numbered_rows AS ( - SELECT *, ROW_NUMBER() OVER ( - PARTITION BY $pkList ORDER BY $cursorOrderClause "${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" DESC - ) AS row_number - FROM records - ) - SELECT $columnList - FROM numbered_rows - WHERE row_number = 1 + | WITH records AS ( + | SELECT + | $columnList + | FROM ${fullyQualifiedName(sourceTableName)} + | ), numbered_rows AS ( + | SELECT *, ROW_NUMBER() OVER ( + | PARTITION BY $pkList ORDER BY $cursorOrderClause "$SNOWFLAKE_AB_EXTRACTED_AT" DESC + | ) AS row_number + | FROM records + | ) + | SELECT $columnList + | FROM numbered_rows + | WHERE row_number = 1 """ - .trimIndent() + .trimMargin() .andLog() } fun dropTable(tableName: TableName): String { - return "DROP TABLE IF EXISTS ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog() + return "DROP TABLE IF EXISTS ${fullyQualifiedName(tableName)}".andLog() } fun getGenerationId( tableName: TableName, ): String { return """ - SELECT "${columnUtils.getGenerationIdColumnName()}" - FROM ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)} + SELECT "${columnManager.getGenerationIdColumnName()}" + FROM ${fullyQualifiedName(tableName)} LIMIT 1 """ .trimIndent() @@ -318,12 +313,12 @@ class SnowflakeDirectLoadSqlGenerator( } fun createSnowflakeStage(tableName: TableName): String { - val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName) + val stageName = fullyQualifiedStageName(tableName) return "CREATE STAGE IF NOT EXISTS $stageName".andLog() } fun putInStage(tableName: TableName, tempFilePath: String): String { - val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true) + val stageName = fullyQualifiedStageName(tableName, true) return """ PUT 'file://$tempFilePath' '@$stageName' 
AUTO_COMPRESS = FALSE @@ -334,35 +329,45 @@ class SnowflakeDirectLoadSqlGenerator( .andLog() } - fun copyFromStage(tableName: TableName, filename: String): String { - val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true) + fun copyFromStage( + tableName: TableName, + filename: String, + columnNames: List? = null + ): String { + val stageName = fullyQualifiedStageName(tableName, true) + val columnList = + columnNames?.let { names -> "(${names.joinToString(", ") { it.quote() }})" } ?: "" return """ - COPY INTO ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)} - FROM '@$stageName' - FILE_FORMAT = ( - TYPE = 'CSV' - COMPRESSION = GZIP - FIELD_DELIMITER = '$CSV_FIELD_SEPARATOR' - RECORD_DELIMITER = '$CSV_LINE_DELIMITER' - FIELD_OPTIONALLY_ENCLOSED_BY = '"' - TRIM_SPACE = TRUE - ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE - REPLACE_INVALID_CHARACTERS = TRUE - ESCAPE = NONE - ESCAPE_UNENCLOSED_FIELD = NONE - ) - ON_ERROR = 'ABORT_STATEMENT' - PURGE = TRUE - files = ('$filename') + |COPY INTO ${fullyQualifiedName(tableName)}$columnList + |FROM '@$stageName' + |FILE_FORMAT = ( + | TYPE = 'CSV' + | COMPRESSION = GZIP + | FIELD_DELIMITER = '$CSV_FIELD_SEPARATOR' + | RECORD_DELIMITER = '$CSV_LINE_DELIMITER' + | FIELD_OPTIONALLY_ENCLOSED_BY = '"' + | TRIM_SPACE = TRUE + | ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE + | REPLACE_INVALID_CHARACTERS = TRUE + | ESCAPE = NONE + | ESCAPE_UNENCLOSED_FIELD = NONE + |) + |ON_ERROR = 'ABORT_STATEMENT' + |PURGE = TRUE + |files = ('$filename') """ - .trimIndent() + .trimMargin() .andLog() } fun swapTableWith(sourceTableName: TableName, targetTableName: TableName): String { return """ - ALTER TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} SWAP WITH ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} + ALTER TABLE ${fullyQualifiedName(sourceTableName)} SWAP WITH ${ + fullyQualifiedName( + targetTableName, + ) + } """ .trimIndent() .andLog() @@ -372,7 +377,11 @@ class SnowflakeDirectLoadSqlGenerator( // Snowflake RENAME TO only accepts the table name, not a fully qualified name // The renamed table stays in the same schema return """ - ALTER TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} RENAME TO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} + ALTER TABLE ${fullyQualifiedName(sourceTableName)} RENAME TO ${ + fullyQualifiedName( + targetTableName, + ) + } """ .trimIndent() .andLog() @@ -382,7 +391,7 @@ class SnowflakeDirectLoadSqlGenerator( schemaName: String, tableName: String, ): String = - """DESCRIBE TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(TableName(schemaName, tableName))}""".andLog() + """DESCRIBE TABLE ${fullyQualifiedName(TableName(schemaName, tableName))}""".andLog() fun alterTable( tableName: TableName, @@ -391,14 +400,14 @@ class SnowflakeDirectLoadSqlGenerator( modifiedColumns: Map, ): Set { val clauses = mutableSetOf() - val prettyTableName = snowflakeSqlNameUtils.fullyQualifiedName(tableName) + val prettyTableName = fullyQualifiedName(tableName) addedColumns.forEach { (name, columnType) -> clauses.add( // Note that we intentionally don't set NOT NULL. // We're adding a new column, and we don't know what constitutes a reasonable // default value for preexisting records. // So we add the column as nullable. 
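// A minimal sketch of how the new optional columnNames parameter of copyFromStage above renders.
// A null list keeps the old behavior (COPY INTO in table column order); a non-null list pins each
// CSV field to a named, quoted column, which stays correct even after ALTER TABLE has appended or
// reordered columns. The table and stage names here are hypothetical.
fun copyIntoHeader(table: String, columnNames: List<String>?): String {
    val columnList = columnNames?.let { names -> "(${names.joinToString(", ") { "\"$it\"" }})" } ?: ""
    return "COPY INTO $table$columnList FROM '@stage'"
}

fun main() {
    println(copyIntoHeader("\"DB\".\"S\".\"T\"", null))                 // COPY INTO "DB"."S"."T" FROM '@stage'
    println(copyIntoHeader("\"DB\".\"S\".\"T\"", listOf("ID", "NAME"))) // COPY INTO "DB"."S"."T"("ID", "NAME") FROM '@stage'
}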
- "ALTER TABLE $prettyTableName ADD COLUMN ${name.quote()} ${columnType.type};".andLog() + "ALTER TABLE $prettyTableName ADD COLUMN ${name.quote()} ${columnType.type};".andLog(), ) } deletedColumns.forEach { @@ -412,35 +421,34 @@ class SnowflakeDirectLoadSqlGenerator( val tempColumn = "${name}_${uuidGenerator.v4()}" clauses.add( // As above: we add the column as nullable. - "ALTER TABLE $prettyTableName ADD COLUMN ${tempColumn.quote()} ${typeChange.newType.type};".andLog() + "ALTER TABLE $prettyTableName ADD COLUMN ${tempColumn.quote()} ${typeChange.newType.type};".andLog(), ) clauses.add( - "UPDATE $prettyTableName SET ${tempColumn.quote()} = CAST(${name.quote()} AS ${typeChange.newType.type});".andLog() + "UPDATE $prettyTableName SET ${tempColumn.quote()} = CAST(${name.quote()} AS ${typeChange.newType.type});".andLog(), ) val backupColumn = "${tempColumn}_backup" clauses.add( """ ALTER TABLE $prettyTableName RENAME COLUMN "$name" TO "$backupColumn"; - """.trimIndent() + """.trimIndent(), ) clauses.add( """ ALTER TABLE $prettyTableName RENAME COLUMN "$tempColumn" TO "$name"; - """.trimIndent() + """.trimIndent(), ) clauses.add( - "ALTER TABLE $prettyTableName DROP COLUMN ${backupColumn.quote()};".andLog() + "ALTER TABLE $prettyTableName DROP COLUMN ${backupColumn.quote()};".andLog(), ) } else if (!typeChange.originalType.nullable && typeChange.newType.nullable) { // If the type is unchanged, we can change a column from NOT NULL to nullable. // But we'll never do the reverse, because there's a decent chance that historical - // records - // had null values. + // records had null values. // Users can always manually ALTER COLUMN ... SET NOT NULL if they want. clauses.add( - """ALTER TABLE $prettyTableName ALTER COLUMN "$name" DROP NOT NULL;""".andLog() + """ALTER TABLE $prettyTableName ALTER COLUMN "$name" DROP NOT NULL;""".andLog(), ) } else { log.info { @@ -450,4 +458,45 @@ class SnowflakeDirectLoadSqlGenerator( } return clauses } + + @VisibleForTesting + fun fullyQualifiedName(tableName: TableName): String = + combineParts(listOf(getDatabaseName(), tableName.namespace, tableName.name)) + + @VisibleForTesting + fun fullyQualifiedNamespace(namespace: String) = + combineParts(listOf(getDatabaseName(), namespace)) + + @VisibleForTesting + fun fullyQualifiedStageName(tableName: TableName, escape: Boolean = false): String { + val currentTableName = + if (escape) { + tableName.name + } else { + tableName.name + } + return combineParts( + parts = + listOf( + getDatabaseName(), + tableName.namespace, + "$STAGE_NAME_PREFIX$currentTableName", + ), + escape = escape, + ) + } + + @VisibleForTesting + internal fun combineParts(parts: List, escape: Boolean = false): String = + parts + .map { if (escape) sqlEscape(it) else it } + .joinToString(separator = ".") { + if (!it.startsWith(QUOTE)) { + "$QUOTE$it$QUOTE" + } else { + it + } + } + + private fun getDatabaseName() = config.database.toSnowflakeCompatibleName() } diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeSqlEscapeUtils.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeSqlEscapeUtils.kt new file mode 100644 index 00000000000..f7071b8c085 --- /dev/null +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeSqlEscapeUtils.kt @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2025 Airbyte, Inc., all rights 
reserved. + */ + +package io.airbyte.integrations.destination.snowflake.sql + +const val STAGE_NAME_PREFIX = "airbyte_stage_" +internal const val QUOTE: String = "\"" + +fun sqlEscape(part: String) = part.replace("\\", "\\\\").replace("'", "\\'").replace("\"", "\\\"") + +/** + * Surrounds the string instance with double quotation marks (e.g. "some string" -> "\"some + * string\""). + */ +fun String.quote() = "$QUOTE$this$QUOTE" + +/** + * Escapes double-quotes in a JSON identifier by doubling them. This behavior is legacy; its + * original motivation is unclear, but it is harmless and is therefore retained. + * + * @return The escaped identifier. + */ +fun escapeJsonIdentifier(identifier: String): String { + // Note that we don't need to escape backslashes here! + // The only special character in an identifier is the double-quote, which needs to be + // doubled. + return identifier.replace(QUOTE, "$QUOTE$QUOTE") +} diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeSqlNameUtils.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeSqlNameUtils.kt deleted file mode 100644 index 1f981f04ecd..00000000000 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeSqlNameUtils.kt +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright (c) 2025 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.integrations.destination.snowflake.sql - -import io.airbyte.cdk.load.table.TableName -import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName -import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration -import jakarta.inject.Singleton - -const val STAGE_NAME_PREFIX = "airbyte_stage_" -internal const val QUOTE: String = "\"" - -fun sqlEscape(part: String) = part.replace("\\", "\\\\").replace("'", "\\'").replace("\"", "\\\"") - -@Singleton -class SnowflakeSqlNameUtils( - private val snowflakeConfiguration: SnowflakeConfiguration, -) { - fun fullyQualifiedName(tableName: TableName): String = - combineParts(listOf(getDatabaseName(), tableName.namespace, tableName.name)) - - fun fullyQualifiedNamespace(namespace: String) = - combineParts(listOf(getDatabaseName(), namespace)) - - fun fullyQualifiedStageName(tableName: TableName, escape: Boolean = false): String { - val currentTableName = - if (escape) { - tableName.name - } else { - tableName.name - } - return combineParts( - parts = - listOf( - getDatabaseName(), - tableName.namespace, - "$STAGE_NAME_PREFIX$currentTableName" - ), - escape = escape, - ) - } - - fun combineParts(parts: List, escape: Boolean = false): String = - parts - .map { if (escape) sqlEscape(it) else it } - .joinToString(separator = ".") { - if (!it.startsWith(QUOTE)) { - "$QUOTE$it$QUOTE" - } else { - it - } - } - - private fun getDatabaseName() = snowflakeConfiguration.database.toSnowflakeCompatibleName() -} diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeWriter.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeWriter.kt index debb4bba2db..36eff5bec03 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeWriter.kt +++
b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeWriter.kt @@ -6,38 +6,39 @@ package io.airbyte.integrations.destination.snowflake.write import io.airbyte.cdk.SystemErrorException import io.airbyte.cdk.load.command.Dedupe +import io.airbyte.cdk.load.command.DestinationCatalog import io.airbyte.cdk.load.command.DestinationStream -import io.airbyte.cdk.load.orchestration.db.DatabaseInitialStatusGatherer -import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator -import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadInitialStatus -import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendStreamLoader -import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendTruncateStreamLoader -import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableDedupStreamLoader -import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableDedupTruncateStreamLoader -import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig -import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog +import io.airbyte.cdk.load.table.ColumnNameMapping +import io.airbyte.cdk.load.table.DatabaseInitialStatusGatherer +import io.airbyte.cdk.load.table.TempTableNameGenerator +import io.airbyte.cdk.load.table.directload.DirectLoadInitialStatus +import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendStreamLoader +import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendTruncateStreamLoader +import io.airbyte.cdk.load.table.directload.DirectLoadTableDedupStreamLoader +import io.airbyte.cdk.load.table.directload.DirectLoadTableDedupTruncateStreamLoader +import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig import io.airbyte.cdk.load.write.DestinationWriter import io.airbyte.cdk.load.write.StreamLoader import io.airbyte.cdk.load.write.StreamStateStore import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient -import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration +import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier import jakarta.inject.Singleton @Singleton class SnowflakeWriter( - private val names: TableCatalog, + private val catalog: DestinationCatalog, private val stateGatherer: DatabaseInitialStatusGatherer, private val streamStateStore: StreamStateStore, private val snowflakeClient: SnowflakeAirbyteClient, - private val tempTableNameGenerator: TempTableNameGenerator, private val snowflakeConfiguration: SnowflakeConfiguration, + private val tempTableNameGenerator: TempTableNameGenerator, ) : DestinationWriter { private lateinit var initialStatuses: Map override suspend fun setup() { - names.values - .map { (tableNames, _) -> tableNames.finalTableName!!.namespace } + catalog.streams + .map { it.tableSchema.tableNames.finalTableName!!.namespace } .toSet() .forEach { snowflakeClient.createNamespace(it) } @@ -45,15 +46,15 @@ class SnowflakeWriter( escapeJsonIdentifier(snowflakeConfiguration.internalTableSchema) ) - initialStatuses = stateGatherer.gatherInitialStatus(names) + initialStatuses = stateGatherer.gatherInitialStatus() } override fun createStreamLoader(stream: DestinationStream): StreamLoader { val initialStatus = initialStatuses[stream]!! - val tableNameInfo = names[stream]!! 
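// A minimal sketch of the namespace handling in setup() above: map every stream to its final
// table's namespace and collapse the result into a Set, so CREATE SCHEMA runs once per schema
// rather than once per stream. The data class is a simplified stand-in for the CDK catalog types.
data class StreamSketch(val namespace: String, val name: String)

fun namespacesToCreate(streams: List<StreamSketch>): Set<String> = streams.map { it.namespace }.toSet()

fun main() {
    val streams = listOf(
        StreamSketch("PUBLIC", "users"),
        StreamSketch("PUBLIC", "orders"),
        StreamSketch("AUDIT", "events"),
    )
    // Two schemas created for three streams.
    namespacesToCreate(streams).forEach { println("CREATE SCHEMA IF NOT EXISTS \"$it\"") }
}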
- val realTableName = tableNameInfo.tableNames.finalTableName!! - val tempTableName = tempTableNameGenerator.generate(realTableName) - val columnNameMapping = tableNameInfo.columnNameMapping + val realTableName = stream.tableSchema.tableNames.finalTableName!! + val tempTableName = stream.tableSchema.tableNames.tempTableName!! + val columnNameMapping = + ColumnNameMapping(stream.tableSchema.columnSchema.inputToFinalColumnNames) return when (stream.minimumGenerationId) { 0L -> when (stream.importType) { diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeInsertBuffer.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeInsertBuffer.kt index 9f61ae1d47b..ee0945081ad 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeInsertBuffer.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeInsertBuffer.kt @@ -9,11 +9,12 @@ import de.siegmar.fastcsv.writer.CsvWriter import de.siegmar.fastcsv.writer.LineDelimiter import de.siegmar.fastcsv.writer.QuoteStrategies import io.airbyte.cdk.load.data.AirbyteValue -import io.airbyte.cdk.load.table.TableName +import io.airbyte.cdk.load.schema.model.ColumnSchema +import io.airbyte.cdk.load.schema.model.TableName import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient +import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.sql.QUOTE -import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils import io.github.oshai.kotlinlogging.KotlinLogging import java.io.File import java.io.OutputStream @@ -36,10 +37,11 @@ private const val CSV_WRITER_BUFFER_SIZE = 1024 * 1024 // 1 MB class SnowflakeInsertBuffer( private val tableName: TableName, - val columns: LinkedHashMap, private val snowflakeClient: SnowflakeAirbyteClient, val snowflakeConfiguration: SnowflakeConfiguration, - val snowflakeColumnUtils: SnowflakeColumnUtils, + val columnSchema: ColumnSchema, + private val columnManager: SnowflakeColumnManager, + private val snowflakeRecordFormatter: SnowflakeRecordFormatter, private val flushLimit: Int = DEFAULT_FLUSH_LIMIT, ) { @@ -57,12 +59,6 @@ class SnowflakeInsertBuffer( .lineDelimiter(CSV_LINE_DELIMITER) .quoteStrategy(QuoteStrategies.REQUIRED) - private val snowflakeRecordFormatter: SnowflakeRecordFormatter = - when (snowflakeConfiguration.legacyRawTablesOnly) { - true -> SnowflakeRawRecordFormatter(columns, snowflakeColumnUtils) - else -> SnowflakeSchemaRecordFormatter(columns, snowflakeColumnUtils) - } - fun accumulate(recordFields: Map) { if (csvFilePath == null) { val csvFile = createCsvFile() @@ -92,7 +88,9 @@ class SnowflakeInsertBuffer( "Copying staging data into ${tableName.toPrettyString(quote = QUOTE)}..." 
} // Finally, copy the data from the staging table to the final table - snowflakeClient.copyFromStage(tableName, filePath.fileName.toString()) + // Pass column names to ensure correct mapping even after ALTER TABLE operations + val columnNames = columnManager.getTableColumnNames(columnSchema) + snowflakeClient.copyFromStage(tableName, filePath.fileName.toString(), columnNames) logger.info { "Finished insert of $recordCount row(s) into ${tableName.toPrettyString(quote = QUOTE)}." } @@ -117,7 +115,9 @@ class SnowflakeInsertBuffer( private fun writeToCsvFile(record: Map) { csvWriter?.let { - it.writeRecord(snowflakeRecordFormatter.format(record).map { col -> col.toString() }) + it.writeRecord( + snowflakeRecordFormatter.format(record, columnSchema).map { col -> col.toString() } + ) recordCount++ if ((recordCount % flushLimit) == 0) { it.flush() diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeRecordFormatter.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeRecordFormatter.kt index 654e9da65de..c51bfc724f7 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeRecordFormatter.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeRecordFormatter.kt @@ -8,99 +8,62 @@ import io.airbyte.cdk.load.data.AirbyteValue import io.airbyte.cdk.load.data.NullValue import io.airbyte.cdk.load.data.StringValue import io.airbyte.cdk.load.data.csv.toCsvValue -import io.airbyte.cdk.load.message.Meta +import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT +import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID +import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT +import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META +import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID +import io.airbyte.cdk.load.schema.model.ColumnSchema import io.airbyte.cdk.load.util.Jsons -import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName -import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils interface SnowflakeRecordFormatter { - fun format(record: Map): List + fun format(record: Map, columnSchema: ColumnSchema): List } -class SnowflakeSchemaRecordFormatter( - private val columns: LinkedHashMap, - val snowflakeColumnUtils: SnowflakeColumnUtils, -) : SnowflakeRecordFormatter { +class SnowflakeSchemaRecordFormatter : SnowflakeRecordFormatter { + override fun format(record: Map, columnSchema: ColumnSchema): List { + val result = mutableListOf() + val userColumns = columnSchema.finalSchema.keys - private val airbyteColumnNames = - snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet() + // WARNING: MUST match the order defined in SnowflakeColumnManager#getTableColumnNames + // + // Why don't we just use that here? Well, unlike the user fields, the meta fields on the + // record are not munged for the destination. So we must access the values for those columns + // using the original lowercase meta key. 
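// A minimal sketch of the flush cadence in writeToCsvFile above: the CSV writer is flushed once
// every flushLimit records instead of per record, bounding buffered data without paying a syscall
// per row; a final flush still happens when the file is staged. The counting writer stands in for
// the real fastcsv writer over a gzip stream.
class CountingCsvWriter(private val flushLimit: Int) {
    private var recordCount = 0
    var flushes = 0
        private set

    fun writeRecord(fields: List<String>) {
        recordCount++
        if (recordCount % flushLimit == 0) flushes++ // stands in for csvWriter.flush()
    }
}

fun main() {
    val writer = CountingCsvWriter(flushLimit = 100)
    repeat(250) { writer.writeRecord(listOf("a", "b")) }
    println(writer.flushes) // 2 intermediate flushes for 250 records
}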
+ result.add(record[COLUMN_NAME_AB_RAW_ID].toCsvValue()) + result.add(record[COLUMN_NAME_AB_EXTRACTED_AT].toCsvValue()) + result.add(record[COLUMN_NAME_AB_META].toCsvValue()) + result.add(record[COLUMN_NAME_AB_GENERATION_ID].toCsvValue()) - override fun format(record: Map): List = - columns.map { (columnName, _) -> - /* - * Meta columns are forced to uppercase for backwards compatibility with previous - * versions of the destination. Therefore, convert the column to lowercase so - * that it can match the constants, which use the lowercase version of the meta - * column names. - */ - if (airbyteColumnNames.contains(columnName)) { - record[columnName.lowercase()].toCsvValue() - } else { - record.keys - // The columns retrieved from Snowflake do not have any escaping applied. - // Therefore, re-apply the compatible name escaping to the name of the - // columns retrieved from Snowflake. The record keys should already have - // been escaped by the CDK before arriving at the aggregate, so no need - // to escape again here. - .find { it == columnName.toSnowflakeCompatibleName() } - ?.let { record[it].toCsvValue() } - ?: "" - } - } + // Add user columns from the final schema + userColumns.forEach { columnName -> result.add(record[columnName].toCsvValue()) } + + return result + } } -class SnowflakeRawRecordFormatter( - columns: LinkedHashMap, - val snowflakeColumnUtils: SnowflakeColumnUtils, -) : SnowflakeRecordFormatter { - private val columns = columns.keys +class SnowflakeRawRecordFormatter : SnowflakeRecordFormatter { - private val airbyteColumnNames = - snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet() - - override fun format(record: Map): List = + override fun format(record: Map, columnSchema: ColumnSchema): List = toOutputRecord(record.toMutableMap()) private fun toOutputRecord(record: MutableMap): List { val outputRecord = mutableListOf() - // Copy the Airbyte metadata columns to the raw output, removing each - // one from the record to avoid duplicates in the "data" field - columns - .filter { airbyteColumnNames.contains(it) && it != Meta.COLUMN_NAME_DATA } - .forEach { column -> safeAddToOutput(column, record, outputRecord) } + val mutableRecord = record.toMutableMap() + + // Add meta columns in order (except _airbyte_data which we handle specially) + outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_RAW_ID)?.toCsvValue() ?: "") + outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_EXTRACTED_AT)?.toCsvValue() ?: "") + outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_META)?.toCsvValue() ?: "") + outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_GENERATION_ID)?.toCsvValue() ?: "") + outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_LOADED_AT)?.toCsvValue() ?: "") + // Do not output null values in the JSON raw output - val filteredRecord = record.filter { (_, v) -> v !is NullValue } - // Convert all the remaining columns in the record to a JSON document stored in the "data" - // column. Add it in the same position as the _airbyte_data column in the column list to - // ensure it is inserted into the proper column in the table. 
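// A minimal sketch of the ordering contract implemented above: the four meta columns first, in a
// fixed order that must match SnowflakeColumnManager#getTableColumnNames, then the user columns in
// final-schema order. Meta values are read under their lowercase CDK keys, user values under
// their munged names; the literal keys below mirror the CDK's Meta constants.
val metaKeys = listOf("_airbyte_raw_id", "_airbyte_extracted_at", "_airbyte_meta", "_airbyte_generation_id")

fun formatSchemaRow(record: Map<String, Any?>, userColumns: List<String>): List<Any?> =
    metaKeys.map { record[it] } + userColumns.map { record[it] }

fun main() {
    val record = mapOf(
        "_airbyte_raw_id" to "uuid-1",
        "_airbyte_extracted_at" to 1_700_000_000L,
        "_airbyte_meta" to "{}",
        "_airbyte_generation_id" to 0L,
        "ID" to 42,
        "NAME" to "ada",
    )
    println(formatSchemaRow(record, listOf("ID", "NAME"))) // [uuid-1, 1700000000, {}, 0, 42, ada]
}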
- insert( - columns.indexOf(Meta.COLUMN_NAME_DATA), - StringValue(Jsons.writeValueAsString(filteredRecord)).toCsvValue(), - outputRecord - ) + val filteredRecord = mutableRecord.filter { (_, v) -> v !is NullValue } + + // Convert all the remaining columns to a JSON document stored in the "data" column + outputRecord.add(StringValue(Jsons.writeValueAsString(filteredRecord)).toCsvValue()) + return outputRecord } - - private fun safeAddToOutput( - key: String, - record: MutableMap, - output: MutableList - ) { - val extractedValue = record.remove(key) - // Ensure that the data is inserted into the list at the same position as the column - insert(columns.indexOf(key), extractedValue?.toCsvValue() ?: "", output) - } - - private fun insert(index: Int, value: Any, list: MutableList) { - /* - * Attempt to insert the value into the proper order in the list. If the index - * is already present in the list, use the add(index, element) method to insert it - * into the proper order and push everything to the right. If the index is at the - * end of the list, just use add(element) to insert it at the end. If the index - * is further beyond the end of the list, throw an exception as that should not occur. - */ - if (index < list.size) list.add(index, value) - else if (index == list.size || index == list.size + 1) list.add(value) - else throw IndexOutOfBoundsException() - } } diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/write/transform/SnowflakeColumnNameMapper.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/write/transform/SnowflakeColumnNameMapper.kt deleted file mode 100644 index d4d5b4e9d71..00000000000 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/write/transform/SnowflakeColumnNameMapper.kt +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright (c) 2025 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.integrations.destination.snowflake.write.transform - -import io.airbyte.cdk.load.command.DestinationStream -import io.airbyte.cdk.load.dataflow.transform.ColumnNameMapper -import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog -import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration -import jakarta.inject.Singleton - -@Singleton -class SnowflakeColumnNameMapper( - private val catalogInfo: TableCatalog, - private val snowflakeConfiguration: SnowflakeConfiguration, -) : ColumnNameMapper { - override fun getMappedColumnName(stream: DestinationStream, columnName: String): String { - if (snowflakeConfiguration.legacyRawTablesOnly == true) { - return columnName - } else { - return catalogInfo.getMappedColumnName(stream, columnName)!! 
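// A minimal sketch of the raw-table row shape produced by SnowflakeRawRecordFormatter above: five
// meta fields emitted positionally (including _airbyte_loaded_at, blank at load time), then every
// remaining non-null field collapsed into a single JSON "data" value. The hand-rolled JSON below
// is a stand-in for Jsons.writeValueAsString.
fun formatRawRow(record: Map<String, Any?>): List<String> {
    val remaining = record.toMutableMap()
    val meta = listOf(
        "_airbyte_raw_id", "_airbyte_extracted_at", "_airbyte_meta",
        "_airbyte_generation_id", "_airbyte_loaded_at",
    ).map { remaining.remove(it)?.toString() ?: "" }
    val data = remaining.entries
        .filter { it.value != null } // stands in for dropping NullValue fields
        .joinToString(prefix = "{", postfix = "}") { (k, v) -> "\"$k\":\"$v\"" }
    return meta + data
}

fun main() {
    println(formatRawRow(mapOf("_airbyte_raw_id" to "uuid-1", "id" to 1, "nickname" to null)))
    // [uuid-1, , , , , {"id":"1"}]
}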
- } - } -} diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/SnowflakeTableOperationsTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/SnowflakeTableOperationsTest.kt index ccfd42096a6..0c2cbfb1ff3 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/SnowflakeTableOperationsTest.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/SnowflakeTableOperationsTest.kt @@ -7,9 +7,11 @@ package io.airbyte.integrations.destination.snowflake.component import io.airbyte.cdk.load.component.TableOperationsFixtures import io.airbyte.cdk.load.component.TableOperationsSuite import io.airbyte.cdk.load.message.Meta +import io.airbyte.cdk.load.schema.TableSchemaFactory import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient -import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.idTestWithCdcMapping -import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.testMapping +import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.idTestWithCdcMapping +import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.testMapping +import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeTestTableOperationsClient import io.micronaut.test.extensions.junit5.annotation.MicronautTest import org.junit.jupiter.api.Test import org.junit.jupiter.api.parallel.Execution @@ -20,6 +22,7 @@ import org.junit.jupiter.api.parallel.ExecutionMode class SnowflakeTableOperationsTest( override val client: SnowflakeAirbyteClient, override val testClient: SnowflakeTestTableOperationsClient, + override val schemaFactory: TableSchemaFactory, ) : TableOperationsSuite { override val airbyteMetaColumnMapping = Meta.COLUMN_NAMES.associateWith { it.uppercase() } diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/SnowflakeTableSchemaEvolutionTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/SnowflakeTableSchemaEvolutionTest.kt index 0a27ef35d7b..7f632b136c6 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/SnowflakeTableSchemaEvolutionTest.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/SnowflakeTableSchemaEvolutionTest.kt @@ -9,13 +9,15 @@ import io.airbyte.cdk.load.component.TableSchemaEvolutionFixtures import io.airbyte.cdk.load.component.TableSchemaEvolutionSuite import io.airbyte.cdk.load.data.StringValue import io.airbyte.cdk.load.message.Meta +import io.airbyte.cdk.load.schema.TableSchemaFactory import io.airbyte.cdk.load.table.ColumnNameMapping import io.airbyte.cdk.load.util.serializeToString import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient -import 
io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.allTypesColumnNameMapping -import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.allTypesTableSchema -import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.idAndTestMapping -import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.testMapping +import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.allTypesColumnNameMapping +import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.allTypesTableSchema +import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.idAndTestMapping +import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.testMapping +import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeTestTableOperationsClient import io.micronaut.test.extensions.junit5.annotation.MicronautTest import org.junit.jupiter.api.Test import org.junit.jupiter.api.parallel.Execution @@ -27,6 +29,7 @@ class SnowflakeTableSchemaEvolutionTest( override val client: SnowflakeAirbyteClient, override val opsClient: SnowflakeAirbyteClient, override val testClient: SnowflakeTestTableOperationsClient, + override val schemaFactory: TableSchemaFactory, ) : TableSchemaEvolutionSuite { override val airbyteMetaColumnMapping = Meta.COLUMN_NAMES.associateWith { it.uppercase() } diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/SnowflakeComponentTestFixtures.kt b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/config/SnowflakeComponentTestFixtures.kt similarity index 90% rename from airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/SnowflakeComponentTestFixtures.kt rename to airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/config/SnowflakeComponentTestFixtures.kt index ee5447385fb..8ff990fab04 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/SnowflakeComponentTestFixtures.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/config/SnowflakeComponentTestFixtures.kt @@ -2,7 +2,7 @@ * Copyright (c) 2025 Airbyte, Inc., all rights reserved. 
*/ -package io.airbyte.integrations.destination.snowflake.component +package io.airbyte.integrations.destination.snowflake.component.config import io.airbyte.cdk.load.component.ColumnType import io.airbyte.cdk.load.component.TableOperationsFixtures @@ -33,6 +33,8 @@ object SnowflakeComponentTestFixtures { "TIME_NTZ" to ColumnType("TIME", true), "ARRAY" to ColumnType("ARRAY", true), "OBJECT" to ColumnType("OBJECT", true), + "UNION" to ColumnType("VARIANT", true), + "LEGACY_UNION" to ColumnType("VARIANT", true), "UNKNOWN" to ColumnType("VARIANT", true), ) ) diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/SnowflakeTestConfigFactory.kt b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/config/SnowflakeTestConfigFactory.kt similarity index 92% rename from airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/SnowflakeTestConfigFactory.kt rename to airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/config/SnowflakeTestConfigFactory.kt index 7ee247aabc0..75e9d824255 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/SnowflakeTestConfigFactory.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/config/SnowflakeTestConfigFactory.kt @@ -2,7 +2,7 @@ * Copyright (c) 2025 Airbyte, Inc., all rights reserved. */ -package io.airbyte.integrations.destination.snowflake.component +package io.airbyte.integrations.destination.snowflake.component.config import io.airbyte.cdk.load.component.config.TestConfigLoader.loadTestConfig import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/SnowflakeTestTableOperationsClient.kt b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/config/SnowflakeTestTableOperationsClient.kt similarity index 75% rename from airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/SnowflakeTestTableOperationsClient.kt rename to airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/config/SnowflakeTestTableOperationsClient.kt index 986d9810cff..1ce0b8e43f7 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/SnowflakeTestTableOperationsClient.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/component/config/SnowflakeTestTableOperationsClient.kt @@ -2,22 +2,24 @@ * Copyright (c) 2025 Airbyte, Inc., all rights reserved. 
*/ -package io.airbyte.integrations.destination.snowflake.component +package io.airbyte.integrations.destination.snowflake.component.config +import io.airbyte.cdk.load.component.ColumnType import io.airbyte.cdk.load.component.TestTableOperationsClient import io.airbyte.cdk.load.data.AirbyteValue import io.airbyte.cdk.load.dataflow.state.PartitionKey import io.airbyte.cdk.load.dataflow.transform.RecordDTO -import io.airbyte.cdk.load.table.TableName +import io.airbyte.cdk.load.schema.model.TableName import io.airbyte.cdk.load.util.Jsons import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient import io.airbyte.integrations.destination.snowflake.client.execute import io.airbyte.integrations.destination.snowflake.dataflow.SnowflakeAggregate +import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration -import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils -import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils +import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator import io.airbyte.integrations.destination.snowflake.sql.andLog import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer +import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter import io.micronaut.context.annotation.Requires import jakarta.inject.Singleton import java.time.format.DateTimeFormatter @@ -29,25 +31,40 @@ import net.snowflake.client.jdbc.SnowflakeTimestampWithTimezone class SnowflakeTestTableOperationsClient( private val client: SnowflakeAirbyteClient, private val dataSource: DataSource, - private val snowflakeSqlNameUtils: SnowflakeSqlNameUtils, - private val snowflakeColumnUtils: SnowflakeColumnUtils, + private val sqlGenerator: SnowflakeDirectLoadSqlGenerator, private val snowflakeConfiguration: SnowflakeConfiguration, + private val columnManager: SnowflakeColumnManager, + private val snowflakeRecordFormatter: SnowflakeRecordFormatter, ) : TestTableOperationsClient { override suspend fun dropNamespace(namespace: String) { dataSource.execute( - "DROP SCHEMA IF EXISTS ${snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)}".andLog() + "DROP SCHEMA IF EXISTS ${sqlGenerator.fullyQualifiedNamespace(namespace)}".andLog() ) } override suspend fun insertRecords(table: TableName, records: List>) { + // TODO: we should just pass a proper column schema + // Since we don't pass in a proper column schema, we have to recreate one here + // Fetch the columns and filter out the meta columns so we're just looking at user columns + val columnTypes = + client.describeTable(table).filterNot { + columnManager.getMetaColumnNames().contains(it.key) + } + val columnSchema = + io.airbyte.cdk.load.schema.model.ColumnSchema( + inputToFinalColumnNames = columnTypes.keys.associateWith { it }, + finalSchema = columnTypes.mapValues { (_, _) -> ColumnType("", true) }, + inputSchema = emptyMap() // Not needed for insert buffer + ) val a = SnowflakeAggregate( SnowflakeInsertBuffer( - table, - client.describeTable(table), - client, - snowflakeConfiguration, - snowflakeColumnUtils, + tableName = table, + snowflakeClient = client, + snowflakeConfiguration = snowflakeConfiguration, + columnSchema = columnSchema, + columnManager = columnManager, + snowflakeRecordFormatter = snowflakeRecordFormatter, ) ) records.forEach { a.accept(RecordDTO(it, PartitionKey(""), 0, 0)) } diff --git 
a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeDataCleaner.kt b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeDataCleaner.kt index 7c1b949c9c1..053d15184b0 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeDataCleaner.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeDataCleaner.kt @@ -7,11 +7,11 @@ package io.airbyte.integrations.destination.snowflake.write import io.airbyte.cdk.load.test.util.DestinationCleaner import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory import io.airbyte.integrations.destination.snowflake.cdk.SnowflakeMigratingConfigurationSpecificationSupplier -import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier -import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName +import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfigurationFactory import io.airbyte.integrations.destination.snowflake.sql.STAGE_NAME_PREFIX +import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier import io.airbyte.integrations.destination.snowflake.sql.quote import java.nio.file.Files import java.sql.Connection diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeDataDumper.kt b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeDataDumper.kt index ad75f1b8746..9a6f45a284a 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeDataDumper.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeDataDumper.kt @@ -12,16 +12,22 @@ import io.airbyte.cdk.load.data.ObjectValue import io.airbyte.cdk.load.data.json.toAirbyteValue import io.airbyte.cdk.load.message.Meta import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN +import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator import io.airbyte.cdk.load.test.util.DestinationDataDumper import io.airbyte.cdk.load.test.util.OutputRecord +import io.airbyte.cdk.load.util.UUIDGenerator import io.airbyte.cdk.load.util.deserializeToNode import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory -import io.airbyte.integrations.destination.snowflake.db.SnowflakeFinalTableNameGenerator -import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName +import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager +import io.airbyte.integrations.destination.snowflake.schema.SnowflakeTableSchemaMapper +import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration -import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils +import 
io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator import io.airbyte.integrations.destination.snowflake.sql.sqlEscape import java.math.BigDecimal +import java.sql.Date +import java.sql.Time +import java.sql.Timestamp import net.snowflake.client.jdbc.SnowflakeTimestampWithTimezone private val AIRBYTE_META_COLUMNS = Meta.COLUMN_NAMES + setOf(CDC_DELETED_AT_COLUMN) @@ -34,8 +40,14 @@ class SnowflakeDataDumper( stream: DestinationStream ): List { val config = configProvider(spec) - val sqlUtils = SnowflakeSqlNameUtils(config) - val snowflakeFinalTableNameGenerator = SnowflakeFinalTableNameGenerator(config) + val snowflakeFinalTableNameGenerator = + SnowflakeTableSchemaMapper( + config = config, + tempTableNameGenerator = DefaultTempTableNameGenerator(), + ) + val snowflakeColumnManager = SnowflakeColumnManager(config) + val sqlGenerator = + SnowflakeDirectLoadSqlGenerator(UUIDGenerator(), config, snowflakeColumnManager) val dataSource = SnowflakeBeanFactory() .snowflakeDataSource(snowflakeConfiguration = config, airbyteEdition = "COMMUNITY") @@ -46,7 +58,7 @@ class SnowflakeDataDumper( ds.connection.use { connection -> val statement = connection.createStatement() val tableName = - snowflakeFinalTableNameGenerator.getTableName(stream.mappedDescriptor) + snowflakeFinalTableNameGenerator.toFinalTableName(stream.mappedDescriptor) // First check if the table exists val tableExistsQuery = @@ -69,7 +81,7 @@ class SnowflakeDataDumper( val resultSet = statement.executeQuery( - "SELECT * FROM ${sqlUtils.fullyQualifiedName(tableName)}" + "SELECT * FROM ${sqlGenerator.fullyQualifiedName(tableName)}" ) while (resultSet.next()) { @@ -143,10 +155,10 @@ class SnowflakeDataDumper( private fun convertValue(value: Any?): Any? = when (value) { is BigDecimal -> value.toBigInteger() - is java.sql.Date -> value.toLocalDate() + is Date -> value.toLocalDate() is SnowflakeTimestampWithTimezone -> value.toZonedDateTime() - is java.sql.Time -> value.toLocalTime() - is java.sql.Timestamp -> value.toLocalDateTime() + is Time -> value.toLocalTime() + is Timestamp -> value.toLocalDateTime() else -> value } } diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeExpectedRecordMapper.kt b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeExpectedRecordMapper.kt index d6cc586298f..ab754d09b0a 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeExpectedRecordMapper.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeExpectedRecordMapper.kt @@ -15,7 +15,7 @@ import io.airbyte.cdk.load.dataflow.transform.ValidationResult import io.airbyte.cdk.load.message.Meta import io.airbyte.cdk.load.test.util.ExpectedRecordMapper import io.airbyte.cdk.load.test.util.OutputRecord -import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName +import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.write.transform.SnowflakeValueCoercer import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Change diff --git 
a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeNameMapper.kt b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeNameMapper.kt index 7c49ab3c589..f1b23cd2416 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeNameMapper.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeNameMapper.kt @@ -5,7 +5,7 @@ package io.airbyte.integrations.destination.snowflake.write import io.airbyte.cdk.load.test.util.NameMapper -import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName +import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName class SnowflakeNameMapper : NameMapper { override fun mapFieldName(path: List): List = diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeRawDataDumper.kt b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeRawDataDumper.kt index 759afa5869e..e08572f25df 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeRawDataDumper.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeRawDataDumper.kt @@ -9,13 +9,16 @@ import io.airbyte.cdk.load.command.DestinationStream import io.airbyte.cdk.load.data.StringValue import io.airbyte.cdk.load.data.json.toAirbyteValue import io.airbyte.cdk.load.message.Meta +import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator import io.airbyte.cdk.load.test.util.DestinationDataDumper import io.airbyte.cdk.load.test.util.OutputRecord +import io.airbyte.cdk.load.util.UUIDGenerator import io.airbyte.cdk.load.util.deserializeToNode import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory -import io.airbyte.integrations.destination.snowflake.db.SnowflakeFinalTableNameGenerator +import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager +import io.airbyte.integrations.destination.snowflake.schema.SnowflakeTableSchemaMapper import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration -import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils +import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator class SnowflakeRawDataDumper( private val configProvider: (ConfigurationSpecification) -> SnowflakeConfiguration @@ -27,8 +30,18 @@ class SnowflakeRawDataDumper( val output = mutableListOf() val config = configProvider(spec) - val sqlUtils = SnowflakeSqlNameUtils(config) - val snowflakeFinalTableNameGenerator = SnowflakeFinalTableNameGenerator(config) + val snowflakeColumnManager = SnowflakeColumnManager(config) + val sqlGenerator = + SnowflakeDirectLoadSqlGenerator( + UUIDGenerator(), + config, + snowflakeColumnManager, + ) + val snowflakeFinalTableNameGenerator = + SnowflakeTableSchemaMapper( + config = config, + tempTableNameGenerator = DefaultTempTableNameGenerator(), + ) val dataSource = 
SnowflakeBeanFactory() .snowflakeDataSource(snowflakeConfiguration = config, airbyteEdition = "COMMUNITY") @@ -37,11 +50,11 @@ class SnowflakeRawDataDumper( ds.connection.use { connection -> val statement = connection.createStatement() val tableName = - snowflakeFinalTableNameGenerator.getTableName(stream.mappedDescriptor) + snowflakeFinalTableNameGenerator.toFinalTableName(stream.mappedDescriptor) val resultSet = statement.executeQuery( - "SELECT * FROM ${sqlUtils.fullyQualifiedName(tableName)}" + "SELECT * FROM ${sqlGenerator.fullyQualifiedName(tableName)}" ) while (resultSet.next()) { diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeBeanFactoryTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeBeanFactoryTest.kt index c5326e0025c..8ef15f637ef 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeBeanFactoryTest.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeBeanFactoryTest.kt @@ -6,7 +6,7 @@ package io.airbyte.integrations.destination.snowflake import com.zaxxer.hikari.HikariConfig import com.zaxxer.hikari.HikariDataSource -import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName +import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.spec.CdcDeletionMode import io.airbyte.integrations.destination.snowflake.spec.KeyPairAuthConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/check/SnowflakeCheckerTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/check/SnowflakeCheckerTest.kt index e69ca17d85b..ac6e892279f 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/check/SnowflakeCheckerTest.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/check/SnowflakeCheckerTest.kt @@ -5,11 +5,9 @@ package io.airbyte.integrations.destination.snowflake.check import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient -import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName +import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager +import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration -import io.airbyte.integrations.destination.snowflake.sql.DEFAULT_COLUMNS -import io.airbyte.integrations.destination.snowflake.sql.RAW_DATA_COLUMN -import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils import io.mockk.coEvery import io.mockk.coVerify import io.mockk.every @@ -23,47 +21,31 @@ internal class SnowflakeCheckerTest { @ParameterizedTest @ValueSource(booleans = [true, false]) fun testSuccessfulCheck(isLegacyRawTablesOnly: Boolean) { - val defaultColumnsMap = - if (isLegacyRawTablesOnly) { - linkedMapOf().also { map -> - (DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach { - 
map[it.columnName] = it.columnType - } - } - } else { - linkedMapOf().also { map -> - (DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach { - map[it.columnName.toSnowflakeCompatibleName()] = it.columnType - } - } - } - val defaultColumns = defaultColumnsMap.keys.toMutableList() val snowflakeAirbyteClient: SnowflakeAirbyteClient = - mockk(relaxed = true) { - coEvery { countTable(any()) } returns 1L - coEvery { describeTable(any()) } returns defaultColumnsMap - } + mockk(relaxed = true) { coEvery { countTable(any()) } returns 1L } val testSchema = "test-schema" val snowflakeConfiguration: SnowflakeConfiguration = mockk { every { schema } returns testSchema every { legacyRawTablesOnly } returns isLegacyRawTablesOnly } - val snowflakeColumnUtils = - mockk(relaxUnitFun = true) { - every { getFormattedDefaultColumnNames(any()) } returns defaultColumns - } + + val columnManager: SnowflakeColumnManager = SnowflakeColumnManager(snowflakeConfiguration) val checker = SnowflakeChecker( snowflakeAirbyteClient = snowflakeAirbyteClient, snowflakeConfiguration = snowflakeConfiguration, - snowflakeColumnUtils = snowflakeColumnUtils, + columnManager = columnManager, ) checker.check() coVerify(exactly = 1) { - snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName()) + if (isLegacyRawTablesOnly) { + snowflakeAirbyteClient.createNamespace(testSchema) + } else { + snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName()) + } } coVerify(exactly = 1) { snowflakeAirbyteClient.createTable(any(), any(), any(), any()) } coVerify(exactly = 1) { snowflakeAirbyteClient.dropTable(any()) } @@ -72,48 +54,32 @@ internal class SnowflakeCheckerTest { @ParameterizedTest @ValueSource(booleans = [true, false]) fun testUnsuccessfulCheck(isLegacyRawTablesOnly: Boolean) { - val defaultColumnsMap = - if (isLegacyRawTablesOnly) { - linkedMapOf().also { map -> - (DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach { - map[it.columnName] = it.columnType - } - } - } else { - linkedMapOf().also { map -> - (DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach { - map[it.columnName.toSnowflakeCompatibleName()] = it.columnType - } - } - } - val defaultColumns = defaultColumnsMap.keys.toMutableList() val snowflakeAirbyteClient: SnowflakeAirbyteClient = - mockk(relaxed = true) { - coEvery { countTable(any()) } returns 0L - coEvery { describeTable(any()) } returns defaultColumnsMap - } + mockk(relaxed = true) { coEvery { countTable(any()) } returns 0L } val testSchema = "test-schema" val snowflakeConfiguration: SnowflakeConfiguration = mockk { every { schema } returns testSchema every { legacyRawTablesOnly } returns isLegacyRawTablesOnly } - val snowflakeColumnUtils = - mockk(relaxUnitFun = true) { - every { getFormattedDefaultColumnNames(any()) } returns defaultColumns - } + + val columnManager: SnowflakeColumnManager = SnowflakeColumnManager(snowflakeConfiguration) val checker = SnowflakeChecker( snowflakeAirbyteClient = snowflakeAirbyteClient, snowflakeConfiguration = snowflakeConfiguration, - snowflakeColumnUtils = snowflakeColumnUtils, + columnManager = columnManager, ) assertThrows { checker.check() } coVerify(exactly = 1) { - snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName()) + if (isLegacyRawTablesOnly) { + snowflakeAirbyteClient.createNamespace(testSchema) + } else { + snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName()) + } } coVerify(exactly = 1) { snowflakeAirbyteClient.createTable(any(), any(), any(), any()) } coVerify(exactly = 1) { 
snowflakeAirbyteClient.dropTable(any()) } diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/client/SnowflakeAirbyteClientTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/client/SnowflakeAirbyteClientTest.kt index 1cdbf90ab48..784916b87eb 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/client/SnowflakeAirbyteClientTest.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/client/SnowflakeAirbyteClientTest.kt @@ -6,24 +6,16 @@ package io.airbyte.integrations.destination.snowflake.client import io.airbyte.cdk.ConfigErrorException import io.airbyte.cdk.load.command.DestinationStream -import io.airbyte.cdk.load.command.NamespaceMapper -import io.airbyte.cdk.load.command.Overwrite import io.airbyte.cdk.load.component.ColumnType -import io.airbyte.cdk.load.config.NamespaceDefinitionType -import io.airbyte.cdk.load.data.AirbyteType -import io.airbyte.cdk.load.data.FieldType import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID +import io.airbyte.cdk.load.schema.model.TableName import io.airbyte.cdk.load.table.ColumnNameMapping -import io.airbyte.cdk.load.table.TableName -import io.airbyte.integrations.destination.snowflake.db.ColumnDefinition -import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName +import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager +import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.sql.COUNT_TOTAL_ALIAS -import io.airbyte.integrations.destination.snowflake.sql.ColumnAndType -import io.airbyte.integrations.destination.snowflake.sql.DEFAULT_COLUMNS import io.airbyte.integrations.destination.snowflake.sql.QUOTE -import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator import io.mockk.Runs import io.mockk.every @@ -49,31 +41,18 @@ internal class SnowflakeAirbyteClientTest { private lateinit var client: SnowflakeAirbyteClient private lateinit var dataSource: DataSource private lateinit var sqlGenerator: SnowflakeDirectLoadSqlGenerator - private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils private lateinit var snowflakeConfiguration: SnowflakeConfiguration + private lateinit var columnManager: SnowflakeColumnManager @BeforeEach fun setup() { dataSource = mockk() sqlGenerator = mockk(relaxed = true) - snowflakeColumnUtils = - mockk(relaxed = true) { - every { formatColumnName(any()) } answers - { - firstArg().toSnowflakeCompatibleName() - } - every { getFormattedDefaultColumnNames(any()) } returns - DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() } - } snowflakeConfiguration = mockk(relaxed = true) { every { database } returns "test_database" } + columnManager = mockk(relaxed = true) client = - SnowflakeAirbyteClient( - dataSource, - sqlGenerator, - snowflakeColumnUtils, - snowflakeConfiguration - ) + SnowflakeAirbyteClient(dataSource, sqlGenerator, snowflakeConfiguration, columnManager) } @Test @@ -231,7 +210,7 @@ internal class 
SnowflakeAirbyteClientTest { @Test fun testCreateTable() { val columnNameMapping = mockk<ColumnNameMapping>(relaxed = true) - val stream = mockk<DestinationStream>() + val stream = mockk<DestinationStream>(relaxed = true) val tableName = TableName(namespace = "namespace", name = "name") val resultSet = mockk<ResultSet>(relaxed = true) val statement = @@ -254,9 +233,7 @@ internal class SnowflakeAirbyteClientTest { columnNameMapping = columnNameMapping, replace = true, ) - verify(exactly = 1) { - sqlGenerator.createTable(stream, tableName, columnNameMapping, true) - } + verify(exactly = 1) { sqlGenerator.createTable(tableName, any(), true) } verify(exactly = 1) { sqlGenerator.createSnowflakeStage(tableName) } verify(exactly = 2) { mockConnection.close() } } @@ -288,7 +265,7 @@ internal class SnowflakeAirbyteClientTest { targetTableName = destinationTableName, ) verify(exactly = 1) { - sqlGenerator.copyTable(columnNameMapping, sourceTableName, destinationTableName) + sqlGenerator.copyTable(any<Set<String>>(), sourceTableName, destinationTableName) } verify(exactly = 1) { mockConnection.close() } } @@ -299,7 +276,7 @@ internal class SnowflakeAirbyteClientTest { val columnNameMapping = mockk<ColumnNameMapping>(relaxed = true) val sourceTableName = TableName(namespace = "namespace", name = "source") val destinationTableName = TableName(namespace = "namespace", name = "destination") - val stream = mockk<DestinationStream>() + val stream = mockk<DestinationStream>(relaxed = true) val resultSet = mockk<ResultSet>(relaxed = true) val statement = mockk<Statement> { @@ -322,12 +299,7 @@ internal class SnowflakeAirbyteClientTest { targetTableName = destinationTableName, ) verify(exactly = 1) { - sqlGenerator.upsertTable( - stream, - columnNameMapping, - sourceTableName, - destinationTableName - ) + sqlGenerator.upsertTable(any(), sourceTableName, destinationTableName) } verify(exactly = 1) { mockConnection.close() } } @@ -379,7 +351,7 @@ internal class SnowflakeAirbyteClientTest { } every { dataSource.connection } returns mockConnection - every { snowflakeColumnUtils.getGenerationIdColumnName() } returns generationIdColumnName + every { columnManager.getGenerationIdColumnName() } returns generationIdColumnName every { sqlGenerator.getGenerationId(tableName) } returns "SELECT $generationIdColumnName FROM ${tableName.toPrettyString(QUOTE)}" @@ -501,8 +473,8 @@ internal class SnowflakeAirbyteClientTest { every { dataSource.connection } returns mockConnection runBlocking { - client.copyFromStage(tableName, "test.csv.gz") - verify(exactly = 1) { sqlGenerator.copyFromStage(tableName, "test.csv.gz") } + client.copyFromStage(tableName, "test.csv.gz", listOf()) + verify(exactly = 1) { sqlGenerator.copyFromStage(tableName, "test.csv.gz", listOf()) } verify(exactly = 1) { mockConnection.close() } } } @@ -556,7 +528,7 @@ internal class SnowflakeAirbyteClientTest { "COL1" andThen COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName() andThen "COL2" - every { resultSet.getString("type") } returns "VARCHAR(255)" andThen "NUMBER(38,0)" + every { resultSet.getString("type") } returns "VARCHAR(255)" andThen "NUMBER" every { resultSet.getString("null?") } returns "Y" andThen "N" andThen "N" val statement = @@ -571,6 +543,10 @@ internal class SnowflakeAirbyteClientTest { every { dataSource.connection } returns connection + // Mock the columnManager to return the correct set of meta columns + every { columnManager.getMetaColumnNames() } returns + setOf(COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName()) + val result = client.getColumnsFromDb(tableName) val expectedColumns = @@ -582,81 +558,6 @@ internal class SnowflakeAirbyteClientTest { assertEquals(expectedColumns, result) } - @Test
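// Hedged aside on the getMetaColumnNames() mock added above (illustration only): when
// reconciling SHOW COLUMNS output, the Airbyte meta columns have to be recognized so
// they are not mistaken for user-schema drift. partitionColumns is a hypothetical
// helper sketching that separation; the real filtering lives in the client/CDK.
fun partitionColumns(
    dbColumns: Map<String, String>, // column name -> type, as read from SHOW COLUMNS
    metaColumnNames: Set<String>, // e.g. "_AIRBYTE_RAW_ID"
): Pair<Map<String, String>, Map<String, String>> {
    val (meta, user) = dbColumns.entries.partition { it.key in metaColumnNames }
    return meta.associate { it.toPair() } to user.associate { it.toPair() }
}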
- fun `getColumnsFromStream should return correct column definitions`() { - val schema = mockk() - val stream = - DestinationStream( - unmappedNamespace = "test_namespace", - unmappedName = "test_stream", - importType = Overwrite, - schema = schema, - generationId = 1, - minimumGenerationId = 1, - syncId = 1, - namespaceMapper = NamespaceMapper(NamespaceDefinitionType.DESTINATION) - ) - val columnNameMapping = - ColumnNameMapping( - mapOf( - "col1" to "COL1_MAPPED", - "col2" to "COL2_MAPPED", - ) - ) - - val col1FieldType = mockk() - every { col1FieldType.type } returns mockk() - - val col2FieldType = mockk() - every { col2FieldType.type } returns mockk() - - every { schema.asColumns() } returns - linkedMapOf("col1" to col1FieldType, "col2" to col2FieldType) - every { snowflakeColumnUtils.toDialectType(col1FieldType.type) } returns "VARCHAR(255)" - every { snowflakeColumnUtils.toDialectType(col2FieldType.type) } returns "NUMBER(38,0)" - every { snowflakeColumnUtils.columnsAndTypes(any(), any()) } returns - listOf(ColumnAndType("COL1_MAPPED", "VARCHAR"), ColumnAndType("COL2_MAPPED", "NUMBER")) - every { snowflakeColumnUtils.formatColumnName(any(), false) } answers - { - firstArg().toSnowflakeCompatibleName() - } - - val result = client.getColumnsFromStream(stream, columnNameMapping) - - val expectedColumns = - mapOf( - "COL1_MAPPED" to ColumnType("VARCHAR", true), - "COL2_MAPPED" to ColumnType("NUMBER", true), - ) - - assertEquals(expectedColumns, result) - } - - @Test - fun `generateSchemaChanges should correctly identify changes`() { - val columnsInDb = - setOf( - ColumnDefinition("COL1", "VARCHAR"), - ColumnDefinition("COL2", "NUMBER"), - ColumnDefinition("COL3", "BOOLEAN") - ) - val columnsInStream = - setOf( - ColumnDefinition("COL1", "VARCHAR"), // Unchanged - ColumnDefinition("COL3", "TEXT"), // Modified - ColumnDefinition("COL4", "DATE") // Added - ) - - val (added, deleted, modified) = client.generateSchemaChanges(columnsInDb, columnsInStream) - - assertEquals(1, added.size) - assertEquals("COL4", added.first().name) - assertEquals(1, deleted.size) - assertEquals("COL2", deleted.first().name) - assertEquals(1, modified.size) - assertEquals("COL3", modified.first().name) - } - @Test fun testCreateNamespaceWithNetworkFailure() { val namespace = "test_namespace" diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/dataflow/SnowflakeAggregateFactoryTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/dataflow/SnowflakeAggregateFactoryTest.kt index 570846fda9d..5d7378a29b1 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/dataflow/SnowflakeAggregateFactoryTest.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/dataflow/SnowflakeAggregateFactoryTest.kt @@ -4,14 +4,19 @@ package io.airbyte.integrations.destination.snowflake.dataflow +import io.airbyte.cdk.load.command.DestinationCatalog +import io.airbyte.cdk.load.command.DestinationStream import io.airbyte.cdk.load.command.DestinationStream.Descriptor import io.airbyte.cdk.load.dataflow.aggregate.StoreKey -import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig -import io.airbyte.cdk.load.table.TableName +import io.airbyte.cdk.load.schema.model.TableName +import 
io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig import io.airbyte.cdk.load.write.StreamStateStore import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient +import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration -import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils +import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRawRecordFormatter +import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter +import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter import io.mockk.every import io.mockk.mockk import org.junit.jupiter.api.Assertions.assertEquals @@ -22,27 +27,33 @@ internal class SnowflakeAggregateFactoryTest { @Test fun testCreatingAggregateWithRawBuffer() { val descriptor = Descriptor(namespace = "namespace", name = "name") - val directLoadTableExecutionConfig = - DirectLoadTableExecutionConfig( - tableName = - TableName( - namespace = descriptor.namespace!!, - name = descriptor.name, - ) + val tableName = + TableName( + namespace = descriptor.namespace!!, + name = descriptor.name, ) + val directLoadTableExecutionConfig = DirectLoadTableExecutionConfig(tableName = tableName) val key = StoreKey(namespace = descriptor.namespace!!, name = descriptor.name) val streamStore = StreamStateStore<DirectLoadTableExecutionConfig>() - streamStore.put(descriptor, directLoadTableExecutionConfig) + streamStore.put(key, directLoadTableExecutionConfig) + + val stream = mockk<DestinationStream>(relaxed = true) + val catalog = mockk<DestinationCatalog> { every { getStream(key) } returns stream } + val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true) val snowflakeConfiguration = mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns true } - val snowflakeColumnUtils = mockk<SnowflakeColumnUtils>(relaxed = true) + val columnManager = SnowflakeColumnManager(snowflakeConfiguration) + val snowflakeRecordFormatter: SnowflakeRecordFormatter = SnowflakeRawRecordFormatter() + val factory = SnowflakeAggregateFactory( snowflakeClient = snowflakeClient, streamStateStore = streamStore, snowflakeConfiguration = snowflakeConfiguration, - snowflakeColumnUtils = snowflakeColumnUtils, + catalog = catalog, + columnManager = columnManager, + snowflakeRecordFormatter = snowflakeRecordFormatter, ) val aggregate = factory.create(key) assertNotNull(aggregate) @@ -52,26 +63,33 @@ internal class SnowflakeAggregateFactoryTest { @Test fun testCreatingAggregateWithStagingBuffer() { val descriptor = Descriptor(namespace = "namespace", name = "name") - val directLoadTableExecutionConfig = - DirectLoadTableExecutionConfig( - tableName = - TableName( - namespace = descriptor.namespace!!, - name = descriptor.name, - ) + val tableName = + TableName( + namespace = descriptor.namespace!!, + name = descriptor.name, ) + val directLoadTableExecutionConfig = DirectLoadTableExecutionConfig(tableName = tableName) val key = StoreKey(namespace = descriptor.namespace!!, name = descriptor.name) val streamStore = StreamStateStore<DirectLoadTableExecutionConfig>() - streamStore.put(descriptor, directLoadTableExecutionConfig) + streamStore.put(key, directLoadTableExecutionConfig) + + val stream = mockk<DestinationStream>(relaxed = true) + val catalog = mockk<DestinationCatalog> { every { getStream(key) } returns stream } + val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true) - val snowflakeConfiguration = mockk<SnowflakeConfiguration>(relaxed = true) - val snowflakeColumnUtils = mockk<SnowflakeColumnUtils>(relaxed = true) + val snowflakeConfiguration = + mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns false } + val
columnManager = SnowflakeColumnManager(snowflakeConfiguration) + val snowflakeRecordFormatter: SnowflakeRecordFormatter = SnowflakeSchemaRecordFormatter() + val factory = SnowflakeAggregateFactory( snowflakeClient = snowflakeClient, streamStateStore = streamStore, snowflakeConfiguration = snowflakeConfiguration, - snowflakeColumnUtils = snowflakeColumnUtils, + catalog = catalog, + columnManager = columnManager, + snowflakeRecordFormatter = snowflakeRecordFormatter, ) val aggregate = factory.create(key) assertNotNull(aggregate) diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/db/SnowflakeColumnNameGeneratorTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/db/SnowflakeColumnNameGeneratorTest.kt deleted file mode 100644 index fcd59965e16..00000000000 --- a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/db/SnowflakeColumnNameGeneratorTest.kt +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (c) 2025 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.integrations.destination.snowflake.db - -import io.mockk.every -import io.mockk.mockk -import org.junit.jupiter.api.Assertions.assertEquals -import org.junit.jupiter.api.Test - -internal class SnowflakeColumnNameGeneratorTest { - - @Test - fun testGetColumnName() { - val column = "test-column" - val generator = - SnowflakeColumnNameGenerator(mockk { every { legacyRawTablesOnly } returns false }) - val columnName = generator.getColumnName(column) - assertEquals(column.toSnowflakeCompatibleName(), columnName.displayName) - assertEquals(column.toSnowflakeCompatibleName(), columnName.canonicalName) - } -} diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/db/SnowflakeFinalTableNameGeneratorTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/db/SnowflakeFinalTableNameGeneratorTest.kt deleted file mode 100644 index 066d4e2c192..00000000000 --- a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/db/SnowflakeFinalTableNameGeneratorTest.kt +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright (c) 2025 Airbyte, Inc., all rights reserved. 
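// Hedged aside (simplified stand-ins, not the CDK classes): the rekeying exercised in
// SnowflakeAggregateFactoryTest above. Per-stream execution config is now stored and
// resolved by StoreKey rather than by the stream Descriptor, so streamStore.put(key, ...)
// and factory.create(key) use the same key type end to end.
data class StoreKeySketch(val namespace: String?, val name: String)

class StreamStateStoreSketch<T> {
    private val states = mutableMapOf<StoreKeySketch, T>()
    fun put(key: StoreKeySketch, state: T) { states[key] = state }
    fun get(key: StoreKeySketch): T? = states[key]
}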
- */ - -package io.airbyte.integrations.destination.snowflake.db - -import io.airbyte.cdk.load.command.DestinationStream -import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration -import io.mockk.every -import io.mockk.mockk -import org.junit.jupiter.api.Assertions.assertEquals -import org.junit.jupiter.api.Test - -internal class SnowflakeFinalTableNameGeneratorTest { - - @Test - fun testGetTableNameWithInternalNamespace() { - val configuration = - mockk { - every { internalTableSchema } returns "test-internal-namespace" - every { legacyRawTablesOnly } returns true - } - val generator = SnowflakeFinalTableNameGenerator(config = configuration) - val streamName = "test-stream-name" - val streamNamespace = "test-stream-namespace" - val streamDescriptor = - mockk { - every { namespace } returns streamNamespace - every { name } returns streamName - } - val tableName = generator.getTableName(streamDescriptor) - assertEquals("test-stream-namespace_raw__stream_test-stream-name", tableName.name) - assertEquals("test-internal-namespace", tableName.namespace) - } - - @Test - fun testGetTableNameWithNamespace() { - val configuration = - mockk { every { legacyRawTablesOnly } returns false } - val generator = SnowflakeFinalTableNameGenerator(config = configuration) - val streamName = "test-stream-name" - val streamNamespace = "test-stream-namespace" - val streamDescriptor = - mockk { - every { namespace } returns streamNamespace - every { name } returns streamName - } - val tableName = generator.getTableName(streamDescriptor) - assertEquals("TEST-STREAM-NAME", tableName.name) - assertEquals("TEST-STREAM-NAMESPACE", tableName.namespace) - } - - @Test - fun testGetTableNameWithDefaultNamespace() { - val defaultNamespace = "test-default-namespace" - val configuration = - mockk { - every { schema } returns defaultNamespace - every { legacyRawTablesOnly } returns false - } - val generator = SnowflakeFinalTableNameGenerator(config = configuration) - val streamName = "test-stream-name" - val streamDescriptor = - mockk { - every { namespace } returns null - every { name } returns streamName - } - val tableName = generator.getTableName(streamDescriptor) - assertEquals("TEST-STREAM-NAME", tableName.name) - assertEquals("TEST-DEFAULT-NAMESPACE", tableName.namespace) - } -} diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/db/SnowflakeNameGeneratorsTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/db/SnowflakeNameGeneratorsTest.kt deleted file mode 100644 index 28d331ae0bd..00000000000 --- a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/db/SnowflakeNameGeneratorsTest.kt +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2025 Airbyte, Inc., all rights reserved. 
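// Hedged summary of what the deleted SnowflakeFinalTableNameGeneratorTest above was
// asserting, as a stand-alone sketch. uppercase() stands in for toSnowflakeCompatibleName(),
// which also escapes special characters (see SnowflakeNameGeneratorsTest below): legacy
// raw mode targets the internal schema with a "<namespace>_raw__stream_<name>" table,
// while the normal path uppercases both namespace and name.
fun sketchTableName(
    legacyRawTablesOnly: Boolean,
    internalSchema: String,
    defaultSchema: String,
    streamNamespace: String?,
    streamName: String,
): Pair<String, String> { // (namespace, table name)
    val ns = streamNamespace ?: defaultSchema
    return if (legacyRawTablesOnly) {
        internalSchema to "${ns}_raw__stream_$streamName"
    } else {
        ns.uppercase() to streamName.uppercase()
    }
}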
- */ - -package io.airbyte.integrations.destination.snowflake.db - -import org.junit.jupiter.api.Assertions.assertEquals -import org.junit.jupiter.params.ParameterizedTest -import org.junit.jupiter.params.provider.CsvSource - -internal class SnowflakeNameGeneratorsTest { - - @ParameterizedTest - @CsvSource( - value = - [ - "test-name,TEST-NAME", - "1-test-name,1-TEST-NAME", - "test-name!!!,TEST-NAME!!!", - "test\${name,TEST__NAME", - "test\"name,TEST\"\"NAME", - ] - ) - fun testToSnowflakeCompatibleName(name: String, expected: String) { - assertEquals(expected, name.toSnowflakeCompatibleName()) - } -} diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeColumnUtilsTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeColumnUtilsTest.kt deleted file mode 100644 index 92ab964d6f1..00000000000 --- a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeColumnUtilsTest.kt +++ /dev/null @@ -1,358 +0,0 @@ -/* - * Copyright (c) 2025 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.integrations.destination.snowflake.sql - -import com.fasterxml.jackson.databind.JsonNode -import io.airbyte.cdk.load.data.ArrayType -import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema -import io.airbyte.cdk.load.data.BooleanType -import io.airbyte.cdk.load.data.DateType -import io.airbyte.cdk.load.data.FieldType -import io.airbyte.cdk.load.data.IntegerType -import io.airbyte.cdk.load.data.NumberType -import io.airbyte.cdk.load.data.ObjectType -import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema -import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema -import io.airbyte.cdk.load.data.StringType -import io.airbyte.cdk.load.data.TimeTypeWithTimezone -import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone -import io.airbyte.cdk.load.data.TimestampTypeWithTimezone -import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone -import io.airbyte.cdk.load.data.UnionType -import io.airbyte.cdk.load.data.UnknownType -import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA -import io.airbyte.cdk.load.orchestration.db.ColumnNameGenerator -import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN -import io.airbyte.cdk.load.table.ColumnNameMapping -import io.airbyte.integrations.destination.snowflake.db.SnowflakeColumnNameGenerator -import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName -import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration -import io.mockk.every -import io.mockk.mockk -import kotlin.collections.LinkedHashMap -import org.junit.jupiter.api.Assertions.assertEquals -import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Test -import org.junit.jupiter.params.ParameterizedTest -import org.junit.jupiter.params.provider.CsvSource - -internal class SnowflakeColumnUtilsTest { - - private lateinit var snowflakeConfiguration: SnowflakeConfiguration - private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils - private lateinit var snowflakeColumnNameGenerator: SnowflakeColumnNameGenerator - - @BeforeEach - fun setup() { - snowflakeConfiguration = mockk(relaxed = true) - snowflakeColumnNameGenerator = - mockk(relaxed = true) { - every { getColumnName(any()) } answers - { - val displayName = - if (snowflakeConfiguration.legacyRawTablesOnly) firstArg() - else 
firstArg().toSnowflakeCompatibleName() - val canonicalName = - if (snowflakeConfiguration.legacyRawTablesOnly) firstArg() - else firstArg().toSnowflakeCompatibleName() - ColumnNameGenerator.ColumnName( - displayName = displayName, - canonicalName = canonicalName, - ) - } - } - snowflakeColumnUtils = - SnowflakeColumnUtils(snowflakeConfiguration, snowflakeColumnNameGenerator) - } - - @Test - fun testDefaultColumns() { - val expectedDefaultColumns = DEFAULT_COLUMNS - assertEquals(expectedDefaultColumns, snowflakeColumnUtils.defaultColumns()) - } - - @Test - fun testDefaultRawColumns() { - every { snowflakeConfiguration.legacyRawTablesOnly } returns true - - val expectedDefaultColumns = DEFAULT_COLUMNS + RAW_COLUMNS - - assertEquals(expectedDefaultColumns, snowflakeColumnUtils.defaultColumns()) - } - - @Test - fun testGetFormattedDefaultColumnNames() { - val expectedDefaultColumnNames = - DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() } - val defaultColumnNames = snowflakeColumnUtils.getFormattedDefaultColumnNames() - assertEquals(expectedDefaultColumnNames, defaultColumnNames) - } - - @Test - fun testGetFormattedDefaultColumnNamesQuoted() { - val expectedDefaultColumnNames = - DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName().quote() } - val defaultColumnNames = snowflakeColumnUtils.getFormattedDefaultColumnNames(true) - assertEquals(expectedDefaultColumnNames, defaultColumnNames) - } - - @Test - fun testGetColumnName() { - val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual")) - val columnNames = snowflakeColumnUtils.getColumnNames(columnNameMapping) - val expectedColumnNames = - (DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() } + listOf("actual")) - .joinToString(",") { it.quote() } - assertEquals(expectedColumnNames, columnNames) - } - - @Test - fun testGetRawColumnName() { - every { snowflakeConfiguration.legacyRawTablesOnly } returns true - - val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual")) - val columnNames = snowflakeColumnUtils.getColumnNames(columnNameMapping) - val expectedColumnNames = - (DEFAULT_COLUMNS.map { it.columnName } + RAW_COLUMNS.map { it.columnName }) - .joinToString(",") { it.quote() } - assertEquals(expectedColumnNames, columnNames) - } - - @Test - fun testGetRawFormattedColumnNames() { - every { snowflakeConfiguration.legacyRawTablesOnly } returns true - val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual")) - val schemaColumns = - mapOf( - "column_one" to FieldType(StringType, true), - "column_two" to FieldType(IntegerType, true), - "original" to FieldType(StringType, true), - CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true) - ) - val expectedColumnNames = - DEFAULT_COLUMNS.map { it.columnName.quote() } + - RAW_COLUMNS.map { it.columnName.quote() } - - val columnNames = - snowflakeColumnUtils.getFormattedColumnNames( - columns = schemaColumns, - columnNameMapping = columnNameMapping - ) - assertEquals(expectedColumnNames.size, columnNames.size) - assertEquals(expectedColumnNames.sorted(), columnNames.sorted()) - } - - @Test - fun testGetFormattedColumnNames() { - val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual")) - val schemaColumns = - mapOf( - "column_one" to FieldType(StringType, true), - "column_two" to FieldType(IntegerType, true), - "original" to FieldType(StringType, true), - CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true) - ) - val expectedColumnNames = - listOf( - "actual", - "column_one", - 
"column_two", - CDC_DELETED_AT_COLUMN, - ) - .map { it.quote() } + - DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName().quote() } - val columnNames = - snowflakeColumnUtils.getFormattedColumnNames( - columns = schemaColumns, - columnNameMapping = columnNameMapping - ) - assertEquals(expectedColumnNames.size, columnNames.size) - assertEquals(expectedColumnNames.sorted(), columnNames.sorted()) - } - - @Test - fun testGetFormattedColumnNamesNoQuotes() { - val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual")) - val schemaColumns = - mapOf( - "column_one" to FieldType(StringType, true), - "column_two" to FieldType(IntegerType, true), - "original" to FieldType(StringType, true), - CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true) - ) - val expectedColumnNames = - listOf( - "actual", - "column_one", - "column_two", - CDC_DELETED_AT_COLUMN, - ) + DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() } - val columnNames = - snowflakeColumnUtils.getFormattedColumnNames( - columns = schemaColumns, - columnNameMapping = columnNameMapping, - quote = false - ) - assertEquals(expectedColumnNames.size, columnNames.size) - assertEquals(expectedColumnNames.sorted(), columnNames.sorted()) - } - - @Test - fun testGeneratingRawTableColumnsAndTypesNoColumnMapping() { - every { snowflakeConfiguration.legacyRawTablesOnly } returns true - - val columns = - snowflakeColumnUtils.columnsAndTypes( - columns = emptyMap(), - columnNameMapping = ColumnNameMapping(emptyMap()) - ) - assertEquals(DEFAULT_COLUMNS.size + RAW_COLUMNS.size, columns.size) - assertEquals( - "${SnowflakeDataType.VARIANT.typeName} $NOT_NULL", - columns.find { it.columnName == RAW_DATA_COLUMN.columnName }?.columnType - ) - } - - @Test - fun testGeneratingColumnsAndTypesNoColumnMapping() { - val columnName = "test-column" - val fieldType = FieldType(StringType, false) - val declaredColumns = mapOf(columnName to fieldType) - - val columns = - snowflakeColumnUtils.columnsAndTypes( - columns = declaredColumns, - columnNameMapping = ColumnNameMapping(emptyMap()) - ) - assertEquals(DEFAULT_COLUMNS.size + 1, columns.size) - assertEquals( - "${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL", - columns.find { it.columnName == columnName }?.columnType - ) - } - - @Test - fun testGeneratingColumnsAndTypesWithColumnMapping() { - val columnName = "test-column" - val mappedColumnName = "mapped-column-name" - val fieldType = FieldType(StringType, false) - val declaredColumns = mapOf(columnName to fieldType) - val columnNameMapping = ColumnNameMapping(mapOf(columnName to mappedColumnName)) - - val columns = - snowflakeColumnUtils.columnsAndTypes( - columns = declaredColumns, - columnNameMapping = columnNameMapping - ) - assertEquals(DEFAULT_COLUMNS.size + 1, columns.size) - assertEquals( - "${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL", - columns.find { it.columnName == mappedColumnName }?.columnType - ) - } - - @Test - fun testToDialectType() { - assertEquals( - SnowflakeDataType.BOOLEAN.typeName, - snowflakeColumnUtils.toDialectType(BooleanType) - ) - assertEquals(SnowflakeDataType.DATE.typeName, snowflakeColumnUtils.toDialectType(DateType)) - assertEquals( - SnowflakeDataType.NUMBER.typeName, - snowflakeColumnUtils.toDialectType(IntegerType) - ) - assertEquals( - SnowflakeDataType.FLOAT.typeName, - snowflakeColumnUtils.toDialectType(NumberType) - ) - assertEquals( - SnowflakeDataType.VARCHAR.typeName, - snowflakeColumnUtils.toDialectType(StringType) - ) - assertEquals( - 
SnowflakeDataType.VARCHAR.typeName, - snowflakeColumnUtils.toDialectType(TimeTypeWithTimezone) - ) - assertEquals( - SnowflakeDataType.TIME.typeName, - snowflakeColumnUtils.toDialectType(TimeTypeWithoutTimezone) - ) - assertEquals( - SnowflakeDataType.TIMESTAMP_TZ.typeName, - snowflakeColumnUtils.toDialectType(TimestampTypeWithTimezone) - ) - assertEquals( - SnowflakeDataType.TIMESTAMP_NTZ.typeName, - snowflakeColumnUtils.toDialectType(TimestampTypeWithoutTimezone) - ) - assertEquals( - SnowflakeDataType.ARRAY.typeName, - snowflakeColumnUtils.toDialectType(ArrayType(items = FieldType(StringType, false))) - ) - assertEquals( - SnowflakeDataType.ARRAY.typeName, - snowflakeColumnUtils.toDialectType(ArrayTypeWithoutSchema) - ) - assertEquals( - SnowflakeDataType.OBJECT.typeName, - snowflakeColumnUtils.toDialectType( - ObjectType( - properties = LinkedHashMap(), - additionalProperties = false, - ) - ) - ) - assertEquals( - SnowflakeDataType.OBJECT.typeName, - snowflakeColumnUtils.toDialectType(ObjectTypeWithEmptySchema) - ) - assertEquals( - SnowflakeDataType.OBJECT.typeName, - snowflakeColumnUtils.toDialectType(ObjectTypeWithoutSchema) - ) - assertEquals( - SnowflakeDataType.VARIANT.typeName, - snowflakeColumnUtils.toDialectType( - UnionType( - options = setOf(StringType), - isLegacyUnion = true, - ) - ) - ) - assertEquals( - SnowflakeDataType.VARIANT.typeName, - snowflakeColumnUtils.toDialectType( - UnionType( - options = emptySet(), - isLegacyUnion = false, - ) - ) - ) - assertEquals( - SnowflakeDataType.VARIANT.typeName, - snowflakeColumnUtils.toDialectType(UnknownType(schema = mockk())) - ) - } - - @ParameterizedTest - @CsvSource( - value = - [ - "$COLUMN_NAME_DATA, true, \"$COLUMN_NAME_DATA\"", - "some-other_Column, true, \"SOME-OTHER_COLUMN\"", - "$COLUMN_NAME_DATA, false, $COLUMN_NAME_DATA", - "some-other_Column, false, SOME-OTHER_COLUMN", - "$COLUMN_NAME_DATA, true, \"$COLUMN_NAME_DATA\"", - "some-other_Column, true, \"SOME-OTHER_COLUMN\"", - ] - ) - fun testFormatColumnName(columnName: String, quote: Boolean, expectedFormattedName: String) { - assertEquals( - expectedFormattedName, - snowflakeColumnUtils.formatColumnName(columnName, quote) - ) - } -} diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeDirectLoadSqlGeneratorTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeDirectLoadSqlGeneratorTest.kt index b0676c48d03..ab83bab85df 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeDirectLoadSqlGeneratorTest.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeDirectLoadSqlGeneratorTest.kt @@ -4,22 +4,21 @@ package io.airbyte.integrations.destination.snowflake.sql +import io.airbyte.cdk.load.command.Append import io.airbyte.cdk.load.command.Dedupe -import io.airbyte.cdk.load.command.DestinationStream import io.airbyte.cdk.load.component.ColumnType import io.airbyte.cdk.load.component.ColumnTypeChange import io.airbyte.cdk.load.data.FieldType -import io.airbyte.cdk.load.data.ObjectType import io.airbyte.cdk.load.data.StringType -import io.airbyte.cdk.load.data.TimestampTypeWithTimezone -import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT import 
io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID -import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA +import io.airbyte.cdk.load.schema.model.ColumnSchema +import io.airbyte.cdk.load.schema.model.StreamTableSchema +import io.airbyte.cdk.load.schema.model.TableName +import io.airbyte.cdk.load.schema.model.TableNames import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN -import io.airbyte.cdk.load.table.ColumnNameMapping -import io.airbyte.cdk.load.table.TableName import io.airbyte.cdk.load.util.UUIDGenerator -import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName +import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager +import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.spec.CdcDeletionMode import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.write.load.CSV_FIELD_SEPARATOR @@ -29,41 +28,40 @@ import io.mockk.mockk import java.util.UUID import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertFalse -import org.junit.jupiter.api.Assertions.assertThrows import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test +/** TODO: These tests are somewhat dubious. */ internal class SnowflakeDirectLoadSqlGeneratorTest { - private lateinit var columnUtils: SnowflakeColumnUtils private lateinit var snowflakeDirectLoadSqlGenerator: SnowflakeDirectLoadSqlGenerator private val uuidGenerator: UUIDGenerator = mockk() private val snowflakeConfiguration: SnowflakeConfiguration = mockk() - private lateinit var snowflakeSqlNameUtils: SnowflakeSqlNameUtils + private val snowflakeColumnManager: SnowflakeColumnManager = mockk() @BeforeEach fun setUp() { every { snowflakeConfiguration.cdcDeletionMode } returns CdcDeletionMode.HARD_DELETE every { snowflakeConfiguration.database } returns "test-database" every { snowflakeConfiguration.legacyRawTablesOnly } returns false - columnUtils = mockk { - every { formatColumnName(any()) } answers - { - val columnName = firstArg() - if (columnName == COLUMN_NAME_DATA) columnName - else columnName.toSnowflakeCompatibleName() - } - every { getGenerationIdColumnName() } returns - COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName() - } - snowflakeSqlNameUtils = SnowflakeSqlNameUtils(snowflakeConfiguration) + + every { snowflakeColumnManager.getMetaColumns() } returns + linkedMapOf( + SNOWFLAKE_AB_RAW_ID to ColumnType("VARCHAR", false), + SNOWFLAKE_AB_EXTRACTED_AT to ColumnType("TIMESTAMP_TZ", false), + SNOWFLAKE_AB_META to ColumnType("VARIANT", false), + SNOWFLAKE_AB_GENERATION_ID to ColumnType("NUMBER", true), + ) + + every { snowflakeColumnManager.getGenerationIdColumnName() } returns + SNOWFLAKE_AB_GENERATION_ID + snowflakeDirectLoadSqlGenerator = SnowflakeDirectLoadSqlGenerator( - columnUtils = columnUtils, uuidGenerator = uuidGenerator, - snowflakeConfiguration = snowflakeConfiguration, - snowflakeSqlNameUtils = snowflakeSqlNameUtils, + config = snowflakeConfiguration, + columnManager = snowflakeColumnManager, ) } @@ -72,7 +70,7 @@ internal class SnowflakeDirectLoadSqlGeneratorTest { val tableName = TableName(namespace = "namespace", name = "name") val sql = snowflakeDirectLoadSqlGenerator.countTable(tableName) assertEquals( - "SELECT COUNT(*) AS TOTAL FROM 
${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}", + "SELECT COUNT(*) AS TOTAL FROM ${snowflakeDirectLoadSqlGenerator.fullyQualifiedName(tableName)}", sql ) } @@ -82,57 +80,82 @@ internal class SnowflakeDirectLoadSqlGeneratorTest { val namespace = "namespace" val sql = snowflakeDirectLoadSqlGenerator.createNamespace(namespace) assertEquals( - "CREATE SCHEMA IF NOT EXISTS ${snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)}", + "CREATE SCHEMA IF NOT EXISTS ${snowflakeDirectLoadSqlGenerator.fullyQualifiedNamespace(namespace)}", sql ) } @Test fun testGenerateCreateTableStatement() { - val columnAndType = - ColumnAndType(columnName = "column-name", columnType = "VARCHAR NOT NULL") - val columnNameMapping = mockk(relaxed = true) - val stream = mockk(relaxed = true) val tableName = TableName(namespace = "namespace", name = "name") - - every { columnUtils.columnsAndTypes(any(), columnNameMapping) } returns - listOf(columnAndType) + val tableSchema = + StreamTableSchema( + tableNames = TableNames(finalTableName = tableName, tempTableName = tableName), + columnSchema = + ColumnSchema( + inputToFinalColumnNames = mapOf("column" to "COLUMN_NAME"), + finalSchema = mapOf("COLUMN_NAME" to ColumnType("VARCHAR", false)), + inputSchema = mapOf("column" to FieldType(StringType, nullable = false)) + ), + importType = Append + ) val sql = snowflakeDirectLoadSqlGenerator.createTable( - stream = stream, tableName = tableName, - columnNameMapping = columnNameMapping, + tableSchema = tableSchema, replace = true ) - assertEquals( - "CREATE OR REPLACE TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)} (\n $columnAndType\n)", - sql - ) + + // The expected SQL should match the exact format + val expectedTableName = snowflakeDirectLoadSqlGenerator.fullyQualifiedName(tableName) + val expectedSql = + """ + CREATE OR REPLACE TABLE $expectedTableName ( + "_AIRBYTE_RAW_ID" VARCHAR NOT NULL, + "_AIRBYTE_EXTRACTED_AT" TIMESTAMP_TZ NOT NULL, + "_AIRBYTE_META" VARIANT NOT NULL, + "_AIRBYTE_GENERATION_ID" NUMBER, + "COLUMN_NAME" VARCHAR NOT NULL + ) + """.trimIndent() + assertEquals(expectedSql, sql) } @Test fun testGenerateCreateTableStatementNoReplace() { - val columnAndType = - ColumnAndType(columnName = "column-name", columnType = "VARCHAR NOT NULL") - val columnNameMapping = mockk(relaxed = true) - val stream = mockk(relaxed = true) val tableName = TableName(namespace = "namespace", name = "name") - - every { columnUtils.columnsAndTypes(any(), columnNameMapping) } returns - listOf(columnAndType) + val tableSchema = + StreamTableSchema( + tableNames = TableNames(finalTableName = tableName, tempTableName = tableName), + columnSchema = + ColumnSchema( + inputToFinalColumnNames = mapOf("column" to "COLUMN_NAME"), + finalSchema = mapOf("COLUMN_NAME" to ColumnType("VARCHAR", false)), + inputSchema = mapOf("column" to FieldType(StringType, nullable = false)) + ), + importType = Append + ) val sql = snowflakeDirectLoadSqlGenerator.createTable( - stream = stream, tableName = tableName, - columnNameMapping = columnNameMapping, + tableSchema = tableSchema, replace = false ) - assertEquals( - "CREATE TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)} (\n $columnAndType\n)", - sql - ) + + val expectedTableName = snowflakeDirectLoadSqlGenerator.fullyQualifiedName(tableName) + val expectedSql = + """ + CREATE TABLE $expectedTableName ( + "_AIRBYTE_RAW_ID" VARCHAR NOT NULL, + "_AIRBYTE_EXTRACTED_AT" TIMESTAMP_TZ NOT NULL, + "_AIRBYTE_META" VARIANT NOT NULL, + "_AIRBYTE_GENERATION_ID" NUMBER, + 
"COLUMN_NAME" VARCHAR NOT NULL + ) + """.trimIndent() + assertEquals(expectedSql, sql) } @Test @@ -140,118 +163,42 @@ internal class SnowflakeDirectLoadSqlGeneratorTest { val tableName = TableName(namespace = "namespace", name = "name") val sql = snowflakeDirectLoadSqlGenerator.showColumns(tableName) assertEquals( - "SHOW COLUMNS IN TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}", + "SHOW COLUMNS IN TABLE ${snowflakeDirectLoadSqlGenerator.fullyQualifiedName(tableName)}", sql ) } @Test fun testGenerateCopyTable() { - val columnName = "column-name" - val mappedColumnName = "mapped-column-name" - val columns = mapOf(columnName to mappedColumnName) - val columnNameMapping = ColumnNameMapping(columns) val columnNames = - (DEFAULT_COLUMNS.map { it.columnName } + "mappedColumnName").joinToString(",") { - it.toSnowflakeCompatibleName().quote() - } + setOf( + "_AIRBYTE_RAW_ID", + "_AIRBYTE_EXTRACTED_AT", + "_AIRBYTE_META", + "_AIRBYTE_GENERATION_ID", + "MAPPED_COLUMN_NAME" + ) val sourceTableName = TableName(namespace = "namespace", name = "source") val destinationTableName = TableName(namespace = "namespace", name = "destination") - every { columnUtils.getColumnNames(columnNameMapping) } returns columnNames - - val expected = - """ - INSERT INTO ${snowflakeSqlNameUtils.fullyQualifiedName(destinationTableName)} - ( - $columnNames - ) - SELECT - $columnNames - FROM ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} - """.trimIndent() val sql = snowflakeDirectLoadSqlGenerator.copyTable( - columnNameMapping = columnNameMapping, + columnNames = columnNames, sourceTableName = sourceTableName, targetTableName = destinationTableName, ) - assertEquals(expected, sql) - } - @Test - fun testGenerateUpsertTable() { - val primaryKey = listOf(listOf("primaryKey")) - val cursor = listOf("cursor") - val stream = - mockk { - every { importType } returns - Dedupe( - primaryKey = primaryKey, - cursor = cursor, - ) - every { schema } returns StringType - } - val columnNameMapping = ColumnNameMapping(emptyMap()) - val sourceTableName = TableName(namespace = "namespace", name = "source") - val destinationTableName = TableName(namespace = "namespace", name = "destination") - val expectedColumns = DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() } - - every { columnUtils.getFormattedColumnNames(any(), columnNameMapping) } returns - expectedColumns.map { it.quote() } - every { columnUtils.getFormattedColumnNames(any(), columnNameMapping, false) } returns - expectedColumns - - val expectedDestinationTable = - "${snowflakeConfiguration.database.toSnowflakeCompatibleName().quote()}.${destinationTableName.namespace.quote()}.${destinationTableName.name.quote()}" - val expectedSourceTable = - "${snowflakeConfiguration.database.toSnowflakeCompatibleName().quote()}.${sourceTableName.namespace.quote()}.${sourceTableName.name.quote()}" + val columnList = columnNames.joinToString(", ") { "\"$it\"" } val expected = """ - MERGE INTO $expectedDestinationTable AS target_table - USING ( - WITH records AS ( - SELECT - ${expectedColumns.joinToString(",\n") { it.quote() } } - FROM $expectedSourceTable - ), numbered_rows AS ( - SELECT *, ROW_NUMBER() OVER ( - PARTITION BY "primaryKey" ORDER BY "cursor" DESC NULLS LAST, "_AIRBYTE_EXTRACTED_AT" DESC - ) AS row_number - FROM records - ) - SELECT ${expectedColumns.joinToString(",\n") { it.quote() } } - FROM numbered_rows - WHERE row_number = 1 - ) AS new_record - ON (target_table."primaryKey" = new_record."primaryKey" OR (target_table."primaryKey" IS NULL 
AND new_record."primaryKey" IS NULL)) - WHEN MATCHED AND ( - target_table."cursor" < new_record."cursor" - OR (target_table."cursor" = new_record."cursor" AND target_table."_AIRBYTE_EXTRACTED_AT" < new_record."_AIRBYTE_EXTRACTED_AT") - OR (target_table."cursor" IS NULL AND new_record."cursor" IS NULL AND target_table."_AIRBYTE_EXTRACTED_AT" < new_record."_AIRBYTE_EXTRACTED_AT") - OR (target_table."cursor" IS NULL AND new_record."cursor" IS NOT NULL) -) THEN UPDATE SET - "_AIRBYTE_RAW_ID" = new_record."_AIRBYTE_RAW_ID", -"_AIRBYTE_EXTRACTED_AT" = new_record."_AIRBYTE_EXTRACTED_AT", -"_AIRBYTE_META" = new_record."_AIRBYTE_META", -"_AIRBYTE_GENERATION_ID" = new_record."_AIRBYTE_GENERATION_ID" - WHEN NOT MATCHED THEN INSERT ( - ${expectedColumns.joinToString(",\n") { it.quote() } } - ) VALUES ( - new_record."_AIRBYTE_RAW_ID", -new_record."_AIRBYTE_EXTRACTED_AT", -new_record."_AIRBYTE_META", -new_record."_AIRBYTE_GENERATION_ID" - ) - """.trimIndent() - - val sql = - snowflakeDirectLoadSqlGenerator.upsertTable( - stream = stream, - columnNameMapping = columnNameMapping, - sourceTableName = sourceTableName, - targetTableName = destinationTableName, + INSERT INTO ${snowflakeDirectLoadSqlGenerator.fullyQualifiedName(destinationTableName)} + ( + $columnList ) + SELECT + $columnList + FROM ${snowflakeDirectLoadSqlGenerator.fullyQualifiedName(sourceTableName)} + """.trimIndent() assertEquals(expected, sql) } @@ -260,7 +207,7 @@ new_record."_AIRBYTE_GENERATION_ID" val tableName = TableName(namespace = "namespace", name = "name") val sql = snowflakeDirectLoadSqlGenerator.dropTable(tableName) assertEquals( - "DROP TABLE IF EXISTS ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}", + "DROP TABLE IF EXISTS ${snowflakeDirectLoadSqlGenerator.fullyQualifiedName(tableName)}", sql ) } @@ -272,7 +219,7 @@ new_record."_AIRBYTE_GENERATION_ID" val expectedSql = """ SELECT "${COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()}" - FROM ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)} + FROM ${snowflakeDirectLoadSqlGenerator.fullyQualifiedName(tableName)} LIMIT 1 """.trimIndent() assertEquals(expectedSql, sql) @@ -281,7 +228,7 @@ new_record."_AIRBYTE_GENERATION_ID" @Test fun testGenerateCreateStage() { val tableName = TableName(namespace = "namespace", name = "name") - val stagingTableName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName) + val stagingTableName = snowflakeDirectLoadSqlGenerator.fullyQualifiedStageName(tableName) val sql = snowflakeDirectLoadSqlGenerator.createSnowflakeStage(tableName) assertEquals("CREATE STAGE IF NOT EXISTS $stagingTableName", sql) } @@ -290,7 +237,7 @@ new_record."_AIRBYTE_GENERATION_ID" fun testGeneratePutInStage() { val tableName = TableName(namespace = "namespace", name = "name") val tempFilePath = "/some/file/path.csv" - val stagingTableName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName) + val stagingTableName = snowflakeDirectLoadSqlGenerator.fullyQualifiedStageName(tableName) val sql = snowflakeDirectLoadSqlGenerator.putInStage(tableName, tempFilePath) val expectedSql = """ @@ -305,316 +252,110 @@ new_record."_AIRBYTE_GENERATION_ID" @Test fun testGenerateCopyFromStage() { val tableName = TableName(namespace = "namespace", name = "name") - val targetTableName = snowflakeSqlNameUtils.fullyQualifiedName(tableName) - val stagingTableName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName) + val targetTableName = snowflakeDirectLoadSqlGenerator.fullyQualifiedName(tableName) + val stagingTableName = 
snowflakeDirectLoadSqlGenerator.fullyQualifiedStageName(tableName) val sql = snowflakeDirectLoadSqlGenerator.copyFromStage(tableName, "test.csv.gz") val expectedSql = """ - COPY INTO $targetTableName - FROM '@$stagingTableName' - FILE_FORMAT = ( - TYPE = 'CSV' - COMPRESSION = GZIP - FIELD_DELIMITER = '$CSV_FIELD_SEPARATOR' - RECORD_DELIMITER = '$CSV_LINE_DELIMITER' - FIELD_OPTIONALLY_ENCLOSED_BY = '"' - TRIM_SPACE = TRUE - ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE - REPLACE_INVALID_CHARACTERS = TRUE - ESCAPE = NONE - ESCAPE_UNENCLOSED_FIELD = NONE - ) - ON_ERROR = 'ABORT_STATEMENT' - PURGE = TRUE - files = ('test.csv.gz') - """.trimIndent() + |COPY INTO $targetTableName + |FROM '@$stagingTableName' + |FILE_FORMAT = ( + | TYPE = 'CSV' + | COMPRESSION = GZIP + | FIELD_DELIMITER = '$CSV_FIELD_SEPARATOR' + | RECORD_DELIMITER = '$CSV_LINE_DELIMITER' + | FIELD_OPTIONALLY_ENCLOSED_BY = '"' + | TRIM_SPACE = TRUE + | ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE + | REPLACE_INVALID_CHARACTERS = TRUE + | ESCAPE = NONE + | ESCAPE_UNENCLOSED_FIELD = NONE + |) + |ON_ERROR = 'ABORT_STATEMENT' + |PURGE = TRUE + |files = ('test.csv.gz') + """.trimMargin() assertEquals(expectedSql, sql) } @Test - fun testGenerateUpsertTableWithCdcHardDelete() { - // Test with CDC hard delete mode when _ab_cdc_deleted_at column is present - val primaryKey = listOf(listOf("id")) - val cursor = listOf("updated_at") - - // Create schema with CDC deletion column - val schemaWithCdc = - ObjectType( - properties = - linkedMapOf( - "id" to FieldType(StringType, nullable = false), - "name" to FieldType(StringType, nullable = true), - "updated_at" to FieldType(TimestampTypeWithTimezone, nullable = true), - CDC_DELETED_AT_COLUMN to - FieldType(TimestampTypeWithTimezone, nullable = true) - ) - ) - - val stream = - mockk { - every { importType } returns - Dedupe( - primaryKey = primaryKey, - cursor = cursor, - ) - every { schema } returns schemaWithCdc - } - - val columnNameMapping = - ColumnNameMapping( - mapOf( - "id" to "id", - "name" to "name", - "updated_at" to "updated_at", - CDC_DELETED_AT_COLUMN to "_ab_cdc_deleted_at" - ) - ) - val sourceTableName = TableName(namespace = "test_ns", name = "source") - val targetTableName = TableName(namespace = "test_ns", name = "target") - - every { columnUtils.getFormattedColumnNames(any(), columnNameMapping, any()) } returns + fun testGenerateCopyFromStageWithColumnList() { + val tableName = TableName(namespace = "namespace", name = "name") + val targetTableName = snowflakeDirectLoadSqlGenerator.fullyQualifiedName(tableName) + val stagingTableName = snowflakeDirectLoadSqlGenerator.fullyQualifiedStageName(tableName) + // Test with uppercase column names (schema mode) + val schemaColumns = listOf( - "id", - "name", - "updated_at", - CDC_DELETED_AT_COLUMN, - ) - .map { it.toSnowflakeCompatibleName().quote() } + - DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName().quote() } - + "_AIRBYTE_RAW_ID", + "_AIRBYTE_EXTRACTED_AT", + "_AIRBYTE_META", + "_AIRBYTE_GENERATION_ID", + "COL1", + "COL2" + ) val sql = - snowflakeDirectLoadSqlGenerator.upsertTable( - stream = stream, - columnNameMapping = columnNameMapping, - sourceTableName = sourceTableName, - targetTableName = targetTableName, - ) - - // Should include the DELETE clause and skip insert clause - assert( - sql.contains( - "WHEN MATCHED AND new_record.${CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName().quote()} IS NOT NULL" - ) - ) - assert(sql.contains("THEN DELETE")) - assert( - sql.contains( - "WHEN NOT MATCHED AND 
new_record.${CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName().quote()} IS NULL THEN INSERT" - ) - ) + snowflakeDirectLoadSqlGenerator.copyFromStage(tableName, "test.csv.gz", schemaColumns) + val expectedSql = + """ + |COPY INTO $targetTableName("_AIRBYTE_RAW_ID", "_AIRBYTE_EXTRACTED_AT", "_AIRBYTE_META", "_AIRBYTE_GENERATION_ID", "COL1", "COL2") + |FROM '@$stagingTableName' + |FILE_FORMAT = ( + | TYPE = 'CSV' + | COMPRESSION = GZIP + | FIELD_DELIMITER = '$CSV_FIELD_SEPARATOR' + | RECORD_DELIMITER = '$CSV_LINE_DELIMITER' + | FIELD_OPTIONALLY_ENCLOSED_BY = '"' + | TRIM_SPACE = TRUE + | ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE + | REPLACE_INVALID_CHARACTERS = TRUE + | ESCAPE = NONE + | ESCAPE_UNENCLOSED_FIELD = NONE + |) + |ON_ERROR = 'ABORT_STATEMENT' + |PURGE = TRUE + |files = ('test.csv.gz') + """.trimMargin() + assertEquals(expectedSql, sql) } @Test - fun testGenerateUpsertTableWithCdcSoftDelete() { - // Test with CDC soft delete mode - should NOT add delete clauses - every { snowflakeConfiguration.cdcDeletionMode } returns CdcDeletionMode.SOFT_DELETE - val softDeleteGenerator = - SnowflakeDirectLoadSqlGenerator( - columnUtils = columnUtils, - uuidGenerator = uuidGenerator, - snowflakeConfiguration = snowflakeConfiguration, - snowflakeSqlNameUtils = snowflakeSqlNameUtils, - ) - - val primaryKey = listOf(listOf("id")) - val cursor = listOf("updated_at") - - // Create schema with CDC deletion column - val schemaWithCdc = - ObjectType( - properties = - linkedMapOf( - "id" to FieldType(StringType, nullable = false), - "name" to FieldType(StringType, nullable = true), - "updated_at" to FieldType(TimestampTypeWithTimezone, nullable = true), - CDC_DELETED_AT_COLUMN to - FieldType(TimestampTypeWithTimezone, nullable = true) - ) - ) - - val stream = - mockk { - every { importType } returns - Dedupe( - primaryKey = primaryKey, - cursor = cursor, - ) - every { schema } returns schemaWithCdc - } - - val columnNameMapping = - ColumnNameMapping( - mapOf( - "id" to "id", - "name" to "name", - "updated_at" to "updated_at", - CDC_DELETED_AT_COLUMN to "_ab_cdc_deleted_at" - ) - ) - val sourceTableName = TableName(namespace = "test_ns", name = "source") - val targetTableName = TableName(namespace = "test_ns", name = "target") - - every { columnUtils.getFormattedColumnNames(any(), columnNameMapping, any()) } returns + fun testGenerateCopyFromStageWithRawModeColumns() { + val tableName = TableName(namespace = "namespace", name = "name") + val targetTableName = snowflakeDirectLoadSqlGenerator.fullyQualifiedName(tableName) + val stagingTableName = snowflakeDirectLoadSqlGenerator.fullyQualifiedStageName(tableName) + // Test with lowercase column names (raw mode) + val rawColumns = listOf( - "id", - "name", - "updated_at", - CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName(), - ) - .map { it.quote() } + DEFAULT_COLUMNS.map { it.columnName.quote() } - every { columnUtils.formatColumnName(any()) } answers { firstArg() } - + "_airbyte_raw_id", + "_airbyte_extracted_at", + "_airbyte_meta", + "_airbyte_generation_id", + "_airbyte_loaded_at", + "_airbyte_data" + ) val sql = - softDeleteGenerator.upsertTable( - stream = stream, - columnNameMapping = columnNameMapping, - sourceTableName = sourceTableName, - targetTableName = targetTableName, - ) - - // Should NOT include DELETE clause in soft delete mode - assert(!sql.contains("THEN DELETE")) - assert(sql.contains("WHEN NOT MATCHED THEN INSERT")) - assert(!sql.contains("AND new_record.${CDC_DELETED_AT_COLUMN.quote()} IS NULL")) - } - - @Test - fun 
testGenerateUpsertTableWithoutCdcColumn() { - // Test that CDC deletion logic is NOT applied when column is absent - val primaryKey = listOf(listOf("id")) - val cursor = listOf("updated_at") - - // Schema without CDC deletion column - val schemaWithoutCdc = - ObjectType( - properties = - linkedMapOf( - "id" to FieldType(StringType, nullable = false), - "name" to FieldType(StringType, nullable = true), - "updated_at" to FieldType(TimestampTypeWithTimezone, nullable = true) - ) - ) - - val stream = - mockk { - every { importType } returns - Dedupe( - primaryKey = primaryKey, - cursor = cursor, - ) - every { schema } returns schemaWithoutCdc - } - - val columnNameMapping = - ColumnNameMapping(mapOf("id" to "id", "name" to "name", "updated_at" to "updated_at")) - val sourceTableName = TableName(namespace = "test_ns", name = "source") - val targetTableName = TableName(namespace = "test_ns", name = "target") - - every { columnUtils.getFormattedColumnNames(any(), columnNameMapping, any()) } returns - listOf( - "id", - "name", - "updated_at", - ) - .map { it.quote() } + DEFAULT_COLUMNS.map { it.columnName.quote() } - every { columnUtils.formatColumnName(any()) } answers { firstArg() } - - val sql = - snowflakeDirectLoadSqlGenerator.upsertTable( - stream = stream, - columnNameMapping = columnNameMapping, - sourceTableName = sourceTableName, - targetTableName = targetTableName, - ) - - // Should NOT include any CDC-related clauses - assertFalse(sql.contains("_ab_cdc_deleted_at")) - assertFalse(sql.contains("THEN DELETE")) - assertTrue(sql.contains("WHEN NOT MATCHED THEN INSERT")) - } - - @Test - fun testGenerateUpsertTableWithNoCursor() { - // Test upsert with no cursor field - val primaryKey = listOf(listOf("id")) - - val schemaWithoutCursor = - ObjectType( - properties = - linkedMapOf( - "id" to FieldType(StringType, nullable = false), - "name" to FieldType(StringType, nullable = true) - ) - ) - - val stream = - mockk { - every { importType } returns - Dedupe( - primaryKey = primaryKey, - cursor = emptyList(), // No cursor - ) - every { schema } returns schemaWithoutCursor - } - - val columnNameMapping = ColumnNameMapping(mapOf("id" to "id", "name" to "name")) - val sourceTableName = TableName(namespace = "test_ns", name = "source") - val targetTableName = TableName(namespace = "test_ns", name = "target") - - every { columnUtils.getFormattedColumnNames(any(), columnNameMapping, any()) } returns - listOf( - "id", - "name", - CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName(), - ) - .map { it.quote() } + DEFAULT_COLUMNS.map { it.columnName.quote() } - every { columnUtils.formatColumnName(any()) } answers { firstArg() } - - val sql = - snowflakeDirectLoadSqlGenerator.upsertTable( - stream = stream, - columnNameMapping = columnNameMapping, - sourceTableName = sourceTableName, - targetTableName = targetTableName, - ) - - // Should use only _airbyte_extracted_at for comparison when no cursor - assert( - sql.contains( - "target_table.${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName().quote()} < new_record.${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName().quote()}" - ) - ) - assert(!sql.contains("target_table.${QUOTE}cursor${QUOTE}")) // No cursor field reference - } - - @Test - fun testGenerateUpsertTableWithoutPrimaryKeyThrowsException() { - // Test that upsert without primary key throws an exception - val stream = - mockk { - every { importType } returns - Dedupe( - primaryKey = emptyList(), // No primary key - cursor = listOf("updated_at"), - ) - every { schema } returns StringType - } - 
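// Hedged aside (string-building sketch, not the generator): the recency rule the
// removed upsert tests encoded. With a cursor, the target row is replaced when the new
// cursor wins, with _AIRBYTE_EXTRACTED_AT as tie-breaker and explicit NULL-cursor
// handling; with no cursor, only _AIRBYTE_EXTRACTED_AT is compared.
fun recencyPredicateSketch(cursor: String?): String {
    val extractedAt = "\"_AIRBYTE_EXTRACTED_AT\""
    if (cursor == null) {
        return "target_table.$extractedAt < new_record.$extractedAt"
    }
    val c = "\"$cursor\""
    return """
        target_table.$c < new_record.$c
        OR (target_table.$c = new_record.$c AND target_table.$extractedAt < new_record.$extractedAt)
        OR (target_table.$c IS NULL AND new_record.$c IS NULL AND target_table.$extractedAt < new_record.$extractedAt)
        OR (target_table.$c IS NULL AND new_record.$c IS NOT NULL)
    """.trimIndent()
}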
- val columnNameMapping = ColumnNameMapping(emptyMap()) - val sourceTableName = TableName(namespace = "test_ns", name = "source") - val targetTableName = TableName(namespace = "test_ns", name = "target") - - val exception = - assertThrows(IllegalArgumentException::class.java) { - snowflakeDirectLoadSqlGenerator.upsertTable( - stream = stream, - columnNameMapping = columnNameMapping, - sourceTableName = sourceTableName, - targetTableName = targetTableName, - ) - } - - assertEquals("Cannot perform upsert without primary key", exception.message) + snowflakeDirectLoadSqlGenerator.copyFromStage(tableName, "test.csv.gz", rawColumns) + val expectedSql = + """ + |COPY INTO $targetTableName("_airbyte_raw_id", "_airbyte_extracted_at", "_airbyte_meta", "_airbyte_generation_id", "_airbyte_loaded_at", "_airbyte_data") + |FROM '@$stagingTableName' + |FILE_FORMAT = ( + | TYPE = 'CSV' + | COMPRESSION = GZIP + | FIELD_DELIMITER = '$CSV_FIELD_SEPARATOR' + | RECORD_DELIMITER = '$CSV_LINE_DELIMITER' + | FIELD_OPTIONALLY_ENCLOSED_BY = '"' + | TRIM_SPACE = TRUE + | ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE + | REPLACE_INVALID_CHARACTERS = TRUE + | ESCAPE = NONE + | ESCAPE_UNENCLOSED_FIELD = NONE + |) + |ON_ERROR = 'ABORT_STATEMENT' + |PURGE = TRUE + |files = ('test.csv.gz') + """.trimMargin() + assertEquals(expectedSql, sql) } @Test @@ -623,7 +364,7 @@ new_record."_AIRBYTE_GENERATION_ID" val targetTableName = TableName(namespace = "namespace", name = "target") val sql = snowflakeDirectLoadSqlGenerator.swapTableWith(sourceTableName, targetTableName) assertEquals( - "ALTER TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} SWAP WITH ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)}", + "ALTER TABLE ${snowflakeDirectLoadSqlGenerator.fullyQualifiedName(sourceTableName)} SWAP WITH ${snowflakeDirectLoadSqlGenerator.fullyQualifiedName(targetTableName)}", sql ) } @@ -712,49 +453,69 @@ new_record."_AIRBYTE_GENERATION_ID" @Test fun testCreateTableWithSQLInjectionAttemptInTableName() { - val tableName = - TableName(namespace = "namespace", name = "table$QUOTE; DROP TABLE users; --") - val stream = mockk(relaxed = true) - val columnNameMapping = mockk(relaxed = true) + val tableName = TableName(namespace = "namespace", name = "table'; DROP TABLE users; --") - every { columnUtils.columnsAndTypes(any(), columnNameMapping) } returns emptyList() + // Create a minimal table schema for testing + val tableSchema = + StreamTableSchema( + tableNames = TableNames(finalTableName = tableName, tempTableName = tableName), + columnSchema = + ColumnSchema( + inputToFinalColumnNames = emptyMap(), + finalSchema = emptyMap(), + inputSchema = emptyMap() + ), + importType = Append + ) - val sql = - snowflakeDirectLoadSqlGenerator.createTable(stream, tableName, columnNameMapping, false) - val expectedTableName = - "${snowflakeConfiguration.database.toSnowflakeCompatibleName().quote()}.${tableName.namespace.quote()}.${tableName.name.quote()}" + val sql = snowflakeDirectLoadSqlGenerator.createTable(tableName, tableSchema, false) + + // The SQL injection attempt should be properly escaped with quotes + val expectedTableName = snowflakeDirectLoadSqlGenerator.fullyQualifiedName(tableName) val expectedSql = """ CREATE TABLE $expectedTableName ( - + "_AIRBYTE_RAW_ID" VARCHAR NOT NULL, + "_AIRBYTE_EXTRACTED_AT" TIMESTAMP_TZ NOT NULL, + "_AIRBYTE_META" VARIANT NOT NULL, + "_AIRBYTE_GENERATION_ID" NUMBER ) - """.trimIndent() + """.trimIndent() - // The dangerous SQL characters should be sanitized to underscores 
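// Sketch of the identifier-quoting rule the SQL-injection tests here lean on (illustrative; the
// connector's own quote() extension may differ in detail): wrapping a name in double quotes and
// doubling any embedded double quote makes Snowflake read the entire payload as a single
// identifier rather than executable SQL.
fun quoteIdentifier(raw: String): String = "\"" + raw.replace("\"", "\"\"") + "\""

// quoteIdentifier("table'; DROP TABLE users; --")
//   => "table'; DROP TABLE users; --"   (one quoted identifier; nothing executes)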
assertEquals(expectedSql, sql) } @Test fun testCreateTableWithSQLInjectionAttemptInNamespace() { - val tableName = - TableName(namespace = "namespace$QUOTE; DROP SCHEMA test; --", name = "table") - val stream = mockk(relaxed = true) - val columnNameMapping = mockk(relaxed = true) + val tableName = TableName(namespace = "namespace'; DROP SCHEMA test; --", name = "table") - every { columnUtils.columnsAndTypes(any(), columnNameMapping) } returns emptyList() + // Create a minimal table schema for testing + val tableSchema = + StreamTableSchema( + tableNames = TableNames(finalTableName = tableName, tempTableName = tableName), + columnSchema = + ColumnSchema( + inputToFinalColumnNames = emptyMap(), + finalSchema = emptyMap(), + inputSchema = emptyMap() + ), + importType = Append + ) - val sql = - snowflakeDirectLoadSqlGenerator.createTable(stream, tableName, columnNameMapping, false) - val expectedTableName = - "${snowflakeConfiguration.database.toSnowflakeCompatibleName().quote()}.${tableName.namespace.quote()}.${tableName.name.quote()}" + val sql = snowflakeDirectLoadSqlGenerator.createTable(tableName, tableSchema, false) + + // The SQL injection attempt should be properly escaped with quotes + val expectedTableName = snowflakeDirectLoadSqlGenerator.fullyQualifiedName(tableName) val expectedSql = """ CREATE TABLE $expectedTableName ( - + "_AIRBYTE_RAW_ID" VARCHAR NOT NULL, + "_AIRBYTE_EXTRACTED_AT" TIMESTAMP_TZ NOT NULL, + "_AIRBYTE_META" VARIANT NOT NULL, + "_AIRBYTE_GENERATION_ID" NUMBER ) - """.trimIndent() + """.trimIndent() - // The dangerous SQL characters should be sanitized to underscores assertEquals(expectedSql, sql) } @@ -762,17 +523,494 @@ new_record."_AIRBYTE_GENERATION_ID" fun testCreateTableWithReservedKeywordsAsNames() { // Test with Snowflake reserved keywords as table/namespace names val tableName = TableName(namespace = "SELECT", name = "WHERE") - val stream = mockk(relaxed = true) - val columnNameMapping = mockk(relaxed = true) - every { columnUtils.columnsAndTypes(any(), columnNameMapping) } returns emptyList() + // Create a minimal table schema for testing + val tableSchema = + StreamTableSchema( + tableNames = TableNames(finalTableName = tableName, tempTableName = tableName), + columnSchema = + ColumnSchema( + inputToFinalColumnNames = emptyMap(), + finalSchema = emptyMap(), + inputSchema = emptyMap() + ), + importType = Append + ) - val sql = - snowflakeDirectLoadSqlGenerator.createTable(stream, tableName, columnNameMapping, false) - val expectedTableName = - "${snowflakeConfiguration.database.toSnowflakeCompatibleName().quote()}.${tableName.namespace.quote()}.${tableName.name.quote()}" + val sql = snowflakeDirectLoadSqlGenerator.createTable(tableName, tableSchema, false) // Reserved keywords should be properly quoted - assertEquals("CREATE TABLE $expectedTableName (\n \n)", sql) + val expectedTableName = snowflakeDirectLoadSqlGenerator.fullyQualifiedName(tableName) + val expectedSql = + """ + CREATE TABLE $expectedTableName ( + "_AIRBYTE_RAW_ID" VARCHAR NOT NULL, + "_AIRBYTE_EXTRACTED_AT" TIMESTAMP_TZ NOT NULL, + "_AIRBYTE_META" VARIANT NOT NULL, + "_AIRBYTE_GENERATION_ID" NUMBER + ) + """.trimIndent() + + assertEquals(expectedSql, sql) + } + + @Test + fun testGenerateUpsertTableWithCdcHardDelete() { + // Test with CDC hard delete mode when _ab_cdc_deleted_at column is present + every { snowflakeConfiguration.cdcDeletionMode } returns CdcDeletionMode.HARD_DELETE + + val sourceTableName = TableName(namespace = "test_ns", name = "source") + val targetTableName = 
TableName(namespace = "test_ns", name = "target") + + // Create table schema with CDC deletion column + val tableSchema = + StreamTableSchema( + tableNames = + TableNames(finalTableName = targetTableName, tempTableName = sourceTableName), + columnSchema = + ColumnSchema( + inputToFinalColumnNames = + mapOf( + "id" to "ID", + "name" to "NAME", + "updated_at" to "UPDATED_AT", + "_ab_cdc_deleted_at" to "_AB_CDC_DELETED_AT" + ), + finalSchema = + mapOf( + "ID" to ColumnType("VARCHAR", false), + "NAME" to ColumnType("VARCHAR", true), + "UPDATED_AT" to ColumnType("TIMESTAMP_TZ", true), + "_AB_CDC_DELETED_AT" to ColumnType("TIMESTAMP_TZ", true) + ), + inputSchema = + mapOf( + "id" to FieldType(StringType, nullable = false), + "name" to FieldType(StringType, nullable = true), + "updated_at" to FieldType(StringType, nullable = true), + "_ab_cdc_deleted_at" to FieldType(StringType, nullable = true) + ) + ), + importType = + Dedupe(primaryKey = listOf(listOf("id")), cursor = listOf("updated_at")) + ) + + val sql = + snowflakeDirectLoadSqlGenerator.upsertTable( + tableSchema, + sourceTableName, + targetTableName + ) + + // Should include the DELETE clause and skip insert clause + assert( + sql.contains( + "WHEN MATCHED AND new_record.${CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName().quote()} IS NOT NULL" + ) + ) + assert(sql.contains("THEN DELETE")) + assert( + sql.contains( + "WHEN NOT MATCHED AND new_record.${CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName().quote()} IS NULL THEN INSERT" + ) + ) + } + + @Test + fun testGenerateUpsertTableWithCdcSoftDelete() { + // Test with CDC soft delete mode - should NOT add delete clauses + every { snowflakeConfiguration.cdcDeletionMode } returns CdcDeletionMode.SOFT_DELETE + + val sourceTableName = TableName(namespace = "test_ns", name = "source") + val targetTableName = TableName(namespace = "test_ns", name = "target") + + // Create table schema with CDC deletion column + val tableSchema = + StreamTableSchema( + tableNames = + TableNames(finalTableName = targetTableName, tempTableName = sourceTableName), + columnSchema = + ColumnSchema( + inputToFinalColumnNames = + mapOf( + "id" to "ID", + "name" to "NAME", + "updated_at" to "UPDATED_AT", + "_ab_cdc_deleted_at" to "_AB_CDC_DELETED_AT" + ), + finalSchema = + mapOf( + "ID" to ColumnType("VARCHAR", false), + "NAME" to ColumnType("VARCHAR", true), + "UPDATED_AT" to ColumnType("TIMESTAMP_TZ", true), + "_AB_CDC_DELETED_AT" to ColumnType("TIMESTAMP_TZ", true) + ), + inputSchema = + mapOf( + "id" to FieldType(StringType, nullable = false), + "name" to FieldType(StringType, nullable = true), + "updated_at" to FieldType(StringType, nullable = true), + "_ab_cdc_deleted_at" to FieldType(StringType, nullable = true) + ) + ), + importType = + Dedupe(primaryKey = listOf(listOf("id")), cursor = listOf("updated_at")) + ) + + val sql = + snowflakeDirectLoadSqlGenerator.upsertTable( + tableSchema, + sourceTableName, + targetTableName + ) + + // Should NOT include DELETE clause in soft delete mode + assert(!sql.contains("THEN DELETE")) + assert(sql.contains("WHEN NOT MATCHED THEN INSERT")) + assert(!sql.contains("AND new_record.${CDC_DELETED_AT_COLUMN.quote()} IS NULL")) + } + + @Test + fun testGenerateUpsertTableWithoutCdc() { + // Configure for no CDC + every { snowflakeConfiguration.cdcDeletionMode } returns CdcDeletionMode.SOFT_DELETE + + val sourceTableName = TableName(namespace = "namespace", name = "source") + val targetTableName = TableName(namespace = "namespace", name = "target") + + // Create a table schema with 
primary key and cursor but no CDC column + val tableSchema = + StreamTableSchema( + tableNames = + TableNames(finalTableName = targetTableName, tempTableName = sourceTableName), + columnSchema = + ColumnSchema( + inputToFinalColumnNames = mapOf("id" to "ID", "updated_at" to "UPDATED_AT"), + finalSchema = + mapOf( + "ID" to ColumnType("VARCHAR", false), + "UPDATED_AT" to ColumnType("TIMESTAMP_TZ", true) + ), + inputSchema = + mapOf( + "id" to FieldType(StringType, nullable = false), + "updated_at" to FieldType(StringType, nullable = true) + ) + ), + importType = + Dedupe(primaryKey = listOf(listOf("id")), cursor = listOf("updated_at")) + ) + + val sql = + snowflakeDirectLoadSqlGenerator.upsertTable( + tableSchema, + sourceTableName, + targetTableName + ) + + // Print the actual SQL for debugging + println("Generated SQL (without CDC):\n$sql") + + // Verify the SQL contains the expected components for non-CDC upsert + assertTrue(sql.contains("MERGE INTO")) + assertTrue(sql.contains("WITH records AS")) + assertTrue(sql.contains("numbered_rows AS")) + assertTrue(sql.contains("ROW_NUMBER() OVER")) + assertTrue(sql.contains("PARTITION BY \"ID\"")) + assertTrue(sql.contains("ORDER BY \"UPDATED_AT\" DESC NULLS LAST")) + // No CDC DELETE clause + assertFalse(sql.contains("_AB_CDC_DELETED_AT")) + assertTrue(sql.contains("WHEN MATCHED AND")) + assertTrue(sql.contains("THEN UPDATE SET")) + assertTrue(sql.contains("WHEN NOT MATCHED")) + assertTrue(sql.contains("THEN INSERT")) + } + + @Test + fun testGenerateUpsertTableNoCursor() { + val sourceTableName = TableName(namespace = "namespace", name = "source") + val targetTableName = TableName(namespace = "namespace", name = "target") + + // Create a table schema with primary key but no cursor + val tableSchema = + StreamTableSchema( + tableNames = + TableNames(finalTableName = targetTableName, tempTableName = sourceTableName), + columnSchema = + ColumnSchema( + inputToFinalColumnNames = mapOf("id" to "ID", "data" to "DATA"), + finalSchema = + mapOf( + "ID" to ColumnType("VARCHAR", false), + "DATA" to ColumnType("VARCHAR", true) + ), + inputSchema = + mapOf( + "id" to FieldType(StringType, nullable = false), + "data" to FieldType(StringType, nullable = true) + ) + ), + importType = Dedupe(primaryKey = listOf(listOf("id")), cursor = emptyList()) + ) + + val sql = + snowflakeDirectLoadSqlGenerator.upsertTable( + tableSchema, + sourceTableName, + targetTableName + ) + + // Print the actual SQL for debugging + println("Generated SQL (no cursor):\n$sql") + + // Verify the SQL contains the expected components for no-cursor upsert + assertTrue(sql.contains("MERGE INTO")) + assertTrue(sql.contains("WITH records AS")) + assertTrue(sql.contains("numbered_rows AS")) + assertTrue(sql.contains("ROW_NUMBER() OVER")) + assertTrue(sql.contains("PARTITION BY \"ID\"")) + // No cursor ordering, just _AIRBYTE_EXTRACTED_AT (may have extra space) + assertTrue(sql.contains("\"_AIRBYTE_EXTRACTED_AT\" DESC")) + assertFalse(sql.contains("UPDATED_AT")) + // Simple extracted_at comparison in WHEN MATCHED + assertTrue( + sql.contains( + "target_table.\"_AIRBYTE_EXTRACTED_AT\" < new_record.\"_AIRBYTE_EXTRACTED_AT\"" + ) + ) + assertTrue(sql.contains("THEN UPDATE SET")) + assertTrue(sql.contains("WHEN NOT MATCHED")) + assertTrue(sql.contains("THEN INSERT")) + } + + @Test + fun testGenerateOverwriteTable() { + // Test overwrite functionality using copyTable with empty destination + val sourceTableName = TableName(namespace = "namespace", name = "source") + val targetTableName = 
TableName(namespace = "namespace", name = "target") + + val columnNames = + setOf( + "_AIRBYTE_RAW_ID", + "_AIRBYTE_EXTRACTED_AT", + "_AIRBYTE_META", + "_AIRBYTE_GENERATION_ID", + "DATA_COL" + ) + + val sql = + snowflakeDirectLoadSqlGenerator.copyTable( + columnNames = columnNames, + sourceTableName = sourceTableName, + targetTableName = targetTableName + ) + + val expectedSourceTable = + snowflakeDirectLoadSqlGenerator.fullyQualifiedName(sourceTableName) + val expectedTargetTable = + snowflakeDirectLoadSqlGenerator.fullyQualifiedName(targetTableName) + + val columnList = columnNames.joinToString(", ") { "\"$it\"" } + val expectedSql = + """ + INSERT INTO $expectedTargetTable + ( + $columnList + ) + SELECT + $columnList + FROM $expectedSourceTable + """.trimIndent() + + assertEquals(expectedSql, sql) + } + + @Test + fun testGenerateUpdateTable() { + // Test update functionality is part of the MERGE statement + val sourceTableName = TableName(namespace = "namespace", name = "source") + val targetTableName = TableName(namespace = "namespace", name = "target") + + // Create a table schema with multiple columns to update + val tableSchema = + StreamTableSchema( + tableNames = + TableNames(finalTableName = targetTableName, tempTableName = sourceTableName), + columnSchema = + ColumnSchema( + inputToFinalColumnNames = + mapOf( + "id" to "ID", + "name" to "NAME", + "value" to "VALUE", + "updated_at" to "UPDATED_AT" + ), + finalSchema = + mapOf( + "ID" to ColumnType("VARCHAR", false), + "NAME" to ColumnType("VARCHAR", true), + "VALUE" to ColumnType("NUMBER", true), + "UPDATED_AT" to ColumnType("TIMESTAMP_TZ", true) + ), + inputSchema = + mapOf( + "id" to FieldType(StringType, nullable = false), + "name" to FieldType(StringType, nullable = true), + "value" to FieldType(StringType, nullable = true), + "updated_at" to FieldType(StringType, nullable = true) + ) + ), + importType = + Dedupe(primaryKey = listOf(listOf("id")), cursor = listOf("updated_at")) + ) + + val sql = + snowflakeDirectLoadSqlGenerator.upsertTable( + tableSchema, + sourceTableName, + targetTableName + ) + + // The MERGE statement should include UPDATE SET clause for all columns + val expectedSourceTable = + snowflakeDirectLoadSqlGenerator.fullyQualifiedName(sourceTableName) + val expectedTargetTable = + snowflakeDirectLoadSqlGenerator.fullyQualifiedName(targetTableName) + + val expectedSql = + """ + |MERGE INTO $expectedTargetTable AS target_table + |USING ( + | WITH records AS ( + | SELECT + | "_AIRBYTE_RAW_ID", + | "_AIRBYTE_EXTRACTED_AT", + | "_AIRBYTE_META", + | "_AIRBYTE_GENERATION_ID", + | "ID", + | "NAME", + | "VALUE", + | "UPDATED_AT" + | FROM $expectedSourceTable + | ), numbered_rows AS ( + | SELECT *, ROW_NUMBER() OVER ( + | PARTITION BY "ID" ORDER BY "UPDATED_AT" DESC NULLS LAST, "_AIRBYTE_EXTRACTED_AT" DESC + | ) AS row_number + | FROM records + | ) + | SELECT "_AIRBYTE_RAW_ID", + | "_AIRBYTE_EXTRACTED_AT", + | "_AIRBYTE_META", + | "_AIRBYTE_GENERATION_ID", + | "ID", + | "NAME", + | "VALUE", + | "UPDATED_AT" + | FROM numbered_rows + | WHERE row_number = 1 + |) AS new_record + |ON (target_table."ID" = new_record."ID" OR (target_table."ID" IS NULL AND new_record."ID" IS NULL)) + |WHEN MATCHED AND ( + | target_table."UPDATED_AT" < new_record."UPDATED_AT" + | OR (target_table."UPDATED_AT" = new_record."UPDATED_AT" AND target_table."_AIRBYTE_EXTRACTED_AT" < new_record."_AIRBYTE_EXTRACTED_AT") + | OR (target_table."UPDATED_AT" IS NULL AND new_record."UPDATED_AT" IS NULL AND target_table."_AIRBYTE_EXTRACTED_AT" < 
new_record."_AIRBYTE_EXTRACTED_AT") + | OR (target_table."UPDATED_AT" IS NULL AND new_record."UPDATED_AT" IS NOT NULL) + |) THEN UPDATE SET + | "_AIRBYTE_RAW_ID" = new_record."_AIRBYTE_RAW_ID", + | "_AIRBYTE_EXTRACTED_AT" = new_record."_AIRBYTE_EXTRACTED_AT", + | "_AIRBYTE_META" = new_record."_AIRBYTE_META", + | "_AIRBYTE_GENERATION_ID" = new_record."_AIRBYTE_GENERATION_ID", + | "ID" = new_record."ID", + | "NAME" = new_record."NAME", + | "VALUE" = new_record."VALUE", + | "UPDATED_AT" = new_record."UPDATED_AT" + |WHEN NOT MATCHED THEN INSERT ( + | "_AIRBYTE_RAW_ID", + | "_AIRBYTE_EXTRACTED_AT", + | "_AIRBYTE_META", + | "_AIRBYTE_GENERATION_ID", + | "ID", + | "NAME", + | "VALUE", + | "UPDATED_AT" + |) VALUES ( + | new_record."_AIRBYTE_RAW_ID", + | new_record."_AIRBYTE_EXTRACTED_AT", + | new_record."_AIRBYTE_META", + | new_record."_AIRBYTE_GENERATION_ID", + | new_record."ID", + | new_record."NAME", + | new_record."VALUE", + | new_record."UPDATED_AT" + |) + """.trimMargin() + + assertEquals(expectedSql, sql) + } + + // Tests moved from SnowflakeSqlNameUtilsTest + @Test + fun testFullyQualifiedNameInCountTable() { + val databaseName = "test-database" + val namespace = "test-namespace" + val name = "test=name" + val tableName = TableName(namespace = namespace, name = name) + every { snowflakeConfiguration.database } returns databaseName + + val sql = snowflakeDirectLoadSqlGenerator.countTable(tableName) + + val expected = + "SELECT COUNT(*) AS TOTAL FROM ${snowflakeDirectLoadSqlGenerator.fullyQualifiedName(tableName)}" + assertEquals(expected, sql) + } + + @Test + fun testFullyQualifiedNamespaceInCreateNamespace() { + val databaseName = "test-database" + val namespace = "test-namespace" + every { snowflakeConfiguration.database } returns databaseName + + val sql = snowflakeDirectLoadSqlGenerator.createNamespace(namespace) + + val expected = + "CREATE SCHEMA IF NOT EXISTS ${snowflakeDirectLoadSqlGenerator.fullyQualifiedNamespace(namespace)}" + assertEquals(expected, sql) + } + + @Test + fun testFullyQualifiedStageNameInCreateStage() { + val databaseName = "test-database" + val namespace = "test-namespace" + val name = "test=name" + val tableName = TableName(namespace = namespace, name = name) + every { snowflakeConfiguration.database } returns databaseName + + val sql = snowflakeDirectLoadSqlGenerator.createSnowflakeStage(tableName) + + val expected = + "CREATE STAGE IF NOT EXISTS ${snowflakeDirectLoadSqlGenerator.fullyQualifiedStageName(tableName)}" + assertEquals(expected, sql) + } + + @Test + fun testFullyQualifiedStageNameWithEscapeInPutInStage() { + val databaseName = "test-database" + val namespace = "test-namespace" + val name = "test=\"\"\'name" + val tableName = TableName(namespace = namespace, name = name) + every { snowflakeConfiguration.database } returns databaseName + + val sql = snowflakeDirectLoadSqlGenerator.putInStage(tableName, "/tmp/test.csv") + + val expectedStageName = + snowflakeDirectLoadSqlGenerator.fullyQualifiedStageName(tableName, true) + val expected = + """ + PUT 'file:///tmp/test.csv' '@$expectedStageName' + AUTO_COMPRESS = FALSE + SOURCE_COMPRESSION = GZIP + OVERWRITE = TRUE + """.trimIndent() + assertEquals(expected, sql) } } diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeSqlNameUtilsTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeSqlNameUtilsTest.kt deleted file mode 100644 index 
843fd52a3ee..00000000000 --- a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/sql/SnowflakeSqlNameUtilsTest.kt +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright (c) 2025 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.integrations.destination.snowflake.sql - -import io.airbyte.cdk.load.table.TableName -import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName -import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration -import io.mockk.every -import io.mockk.mockk -import org.junit.jupiter.api.Assertions.assertEquals -import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Test - -internal class SnowflakeSqlNameUtilsTest { - - private lateinit var snowflakeConfiguration: SnowflakeConfiguration - private lateinit var snowflakeSqlNameUtils: SnowflakeSqlNameUtils - - @BeforeEach - fun setUp() { - snowflakeConfiguration = mockk(relaxed = true) - snowflakeSqlNameUtils = - SnowflakeSqlNameUtils(snowflakeConfiguration = snowflakeConfiguration) - } - - @Test - fun testFullyQualifiedName() { - val databaseName = "test-database" - val namespace = "test-namespace" - val name = "test=name" - val tableName = TableName(namespace = namespace, name = name) - every { snowflakeConfiguration.database } returns databaseName - - val expectedName = - snowflakeSqlNameUtils.combineParts( - listOf( - databaseName.toSnowflakeCompatibleName(), - tableName.namespace, - tableName.name - ) - ) - val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedName(tableName) - assertEquals(expectedName, fullyQualifiedName) - } - - @Test - fun testFullyQualifiedNamespace() { - val databaseName = "test-database" - val namespace = "test-namespace" - every { snowflakeConfiguration.database } returns databaseName - - val fullyQualifiedNamespace = snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace) - assertEquals("\"TEST-DATABASE\".\"test-namespace\"", fullyQualifiedNamespace) - } - - @Test - fun testFullyQualifiedStageName() { - val databaseName = "test-database" - val namespace = "test-namespace" - val name = "test=name" - val tableName = TableName(namespace = namespace, name = name) - every { snowflakeConfiguration.database } returns databaseName - - val expectedName = - snowflakeSqlNameUtils.combineParts( - listOf( - databaseName.toSnowflakeCompatibleName(), - namespace, - "$STAGE_NAME_PREFIX$name" - ) - ) - val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName) - assertEquals(expectedName, fullyQualifiedName) - } - - @Test - fun testFullyQualifiedStageNameWithEscape() { - val databaseName = "test-database" - val namespace = "test-namespace" - val name = "test=\"\"\'name" - val tableName = TableName(namespace = namespace, name = name) - every { snowflakeConfiguration.database } returns databaseName - - val expectedName = - snowflakeSqlNameUtils.combineParts( - listOf( - databaseName.toSnowflakeCompatibleName(), - namespace, - "$STAGE_NAME_PREFIX${sqlEscape(name)}" - ) - ) - val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true) - assertEquals(expectedName, fullyQualifiedName) - } -} diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeWriterTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeWriterTest.kt index c0435d0cda3..0a510398542 100644 
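// The deleted SnowflakeSqlNameUtilsTest above covered three-part name composition on its own;
// those checks now ride through the generator's fullyQualifiedName(...). A rough equivalent of
// the expected shape (helper and names are illustrative, not the connector's API):
fun fullyQualified(database: String, namespace: String, name: String): String =
    listOf(database, namespace, name).joinToString(".") { part ->
        // Each part is independently quoted, so dashes and '=' survive verbatim.
        "\"" + part.replace("\"", "\"\"") + "\""
    }

// fullyQualified("TEST-DATABASE", "test-namespace", "test=name")
//   => "TEST-DATABASE"."test-namespace"."test=name"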
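// Sketch of the deletion-mode branching the CDC upsert tests earlier in this file assert
// (illustrative shape only; the real generator also emits the MERGE's UPDATE clauses): hard
// delete emits a guarded DELETE and guards the INSERT, while soft delete emits neither guard,
// so deletion markers land in the table as ordinary column values.
enum class CdcMode { HARD_DELETE, SOFT_DELETE }

fun cdcMergeClauses(mode: CdcMode, hasCdcColumn: Boolean): List<String> =
    if (hasCdcColumn && mode == CdcMode.HARD_DELETE)
        listOf(
            "WHEN MATCHED AND new_record.\"_AB_CDC_DELETED_AT\" IS NOT NULL THEN DELETE",
            "WHEN NOT MATCHED AND new_record.\"_AB_CDC_DELETED_AT\" IS NULL THEN INSERT"
        )
    else
        listOf("WHEN NOT MATCHED THEN INSERT")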
--- a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeWriterTest.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/write/SnowflakeWriterTest.kt @@ -6,20 +6,21 @@ package io.airbyte.integrations.destination.snowflake.write import io.airbyte.cdk.SystemErrorException import io.airbyte.cdk.load.command.Append +import io.airbyte.cdk.load.command.DestinationCatalog import io.airbyte.cdk.load.command.DestinationStream -import io.airbyte.cdk.load.orchestration.db.DatabaseInitialStatusGatherer -import io.airbyte.cdk.load.orchestration.db.TableNames -import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator -import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadInitialStatus -import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendStreamLoader -import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendTruncateStreamLoader -import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableStatus -import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog -import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableNameInfo -import io.airbyte.cdk.load.table.ColumnNameMapping -import io.airbyte.cdk.load.table.TableName +import io.airbyte.cdk.load.schema.model.ColumnSchema +import io.airbyte.cdk.load.schema.model.StreamTableSchema +import io.airbyte.cdk.load.schema.model.TableName +import io.airbyte.cdk.load.schema.model.TableNames +import io.airbyte.cdk.load.table.DatabaseInitialStatusGatherer +import io.airbyte.cdk.load.table.TempTableNameGenerator +import io.airbyte.cdk.load.table.directload.DirectLoadInitialStatus +import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendStreamLoader +import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendTruncateStreamLoader +import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig +import io.airbyte.cdk.load.table.directload.DirectLoadTableStatus +import io.airbyte.cdk.load.write.StreamStateStore import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient -import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName import io.mockk.coEvery import io.mockk.coVerify import io.mockk.every @@ -34,55 +35,93 @@ internal class SnowflakeWriterTest { @Test fun testSetup() { val tableName = TableName(namespace = "test-namespace", name = "test-name") - val tableNames = TableNames(rawTableName = null, finalTableName = tableName) - val stream = mockk() - val tableInfo = - TableNameInfo( - tableNames = tableNames, - columnNameMapping = ColumnNameMapping(emptyMap()) - ) - val catalog = TableCatalog(mapOf(stream to tableInfo)) + val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp") + val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName) + val stream = + mockk { + every { tableSchema } returns + StreamTableSchema( + tableNames = tableNames, + columnSchema = + ColumnSchema( + inputToFinalColumnNames = emptyMap(), + finalSchema = emptyMap(), + inputSchema = emptyMap() + ), + importType = Append + ) + every { mappedDescriptor } returns + DestinationStream.Descriptor( + namespace = tableName.namespace, + name = tableName.name + ) + every { importType } returns Append + } + val catalog = DestinationCatalog(listOf(stream)) val snowflakeClient = mockk(relaxed = true) val 
stateGatherer = mockk> { - coEvery { gatherInitialStatus(catalog) } returns emptyMap() + coEvery { gatherInitialStatus() } returns + mapOf( + stream to + DirectLoadInitialStatus( + realTable = DirectLoadTableStatus(false), + tempTable = null + ) + ) } + val streamStateStore = mockk>() val writer = SnowflakeWriter( - names = catalog, + catalog = catalog, stateGatherer = stateGatherer, - streamStateStore = mockk(), + streamStateStore = streamStateStore, snowflakeClient = snowflakeClient, tempTableNameGenerator = mockk(), - snowflakeConfiguration = mockk(relaxed = true), + snowflakeConfiguration = + mockk(relaxed = true) { + every { internalTableSchema } returns "internal_schema" + }, ) runBlocking { writer.setup() } coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) } - coVerify(exactly = 1) { stateGatherer.gatherInitialStatus(catalog) } + coVerify(exactly = 1) { stateGatherer.gatherInitialStatus() } } @Test fun testCreateStreamLoaderFirstGeneration() { val tableName = TableName(namespace = "test-namespace", name = "test-name") - val tableNames = TableNames(rawTableName = null, finalTableName = tableName) + val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp") + val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName) val stream = mockk { every { minimumGenerationId } returns 0L every { generationId } returns 0L every { importType } returns Append + every { tableSchema } returns + StreamTableSchema( + tableNames = tableNames, + columnSchema = + ColumnSchema( + inputToFinalColumnNames = emptyMap(), + finalSchema = emptyMap(), + inputSchema = emptyMap() + ), + importType = Append + ) + every { mappedDescriptor } returns + DestinationStream.Descriptor( + namespace = tableName.namespace, + name = tableName.name + ) } - val tableInfo = - TableNameInfo( - tableNames = tableNames, - columnNameMapping = ColumnNameMapping(emptyMap()) - ) - val catalog = TableCatalog(mapOf(stream to tableInfo)) + val catalog = DestinationCatalog(listOf(stream)) val snowflakeClient = mockk(relaxed = true) val stateGatherer = mockk> { - coEvery { gatherInitialStatus(catalog) } returns + coEvery { gatherInitialStatus() } returns mapOf( stream to DirectLoadInitialStatus( @@ -93,14 +132,18 @@ internal class SnowflakeWriterTest { } val tempTableNameGenerator = mockk { every { generate(any()) } answers { firstArg() } } + val streamStateStore = mockk>() val writer = SnowflakeWriter( - names = catalog, + catalog = catalog, stateGatherer = stateGatherer, - streamStateStore = mockk(), + streamStateStore = streamStateStore, snowflakeClient = snowflakeClient, tempTableNameGenerator = tempTableNameGenerator, - snowflakeConfiguration = mockk(relaxed = true), + snowflakeConfiguration = + mockk(relaxed = true) { + every { internalTableSchema } returns "internal_schema" + }, ) runBlocking { @@ -113,23 +156,35 @@ internal class SnowflakeWriterTest { @Test fun testCreateStreamLoaderNotFirstGeneration() { val tableName = TableName(namespace = "test-namespace", name = "test-name") - val tableNames = TableNames(rawTableName = null, finalTableName = tableName) + val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp") + val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName) val stream = mockk { every { minimumGenerationId } returns 1L every { generationId } returns 1L every { importType } returns Append + every { tableSchema } returns + StreamTableSchema( + tableNames = tableNames, + columnSchema 
= + ColumnSchema( + inputToFinalColumnNames = emptyMap(), + finalSchema = emptyMap(), + inputSchema = emptyMap() + ), + importType = Append + ) + every { mappedDescriptor } returns + DestinationStream.Descriptor( + namespace = tableName.namespace, + name = tableName.name + ) } - val tableInfo = - TableNameInfo( - tableNames = tableNames, - columnNameMapping = ColumnNameMapping(emptyMap()) - ) - val catalog = TableCatalog(mapOf(stream to tableInfo)) + val catalog = DestinationCatalog(listOf(stream)) val snowflakeClient = mockk(relaxed = true) val stateGatherer = mockk> { - coEvery { gatherInitialStatus(catalog) } returns + coEvery { gatherInitialStatus() } returns mapOf( stream to DirectLoadInitialStatus( @@ -140,14 +195,18 @@ internal class SnowflakeWriterTest { } val tempTableNameGenerator = mockk { every { generate(any()) } answers { firstArg() } } + val streamStateStore = mockk>() val writer = SnowflakeWriter( - names = catalog, + catalog = catalog, stateGatherer = stateGatherer, - streamStateStore = mockk(), + streamStateStore = streamStateStore, snowflakeClient = snowflakeClient, tempTableNameGenerator = tempTableNameGenerator, - snowflakeConfiguration = mockk(relaxed = true), + snowflakeConfiguration = + mockk(relaxed = true) { + every { internalTableSchema } returns "internal_schema" + }, ) runBlocking { @@ -160,22 +219,35 @@ internal class SnowflakeWriterTest { @Test fun testCreateStreamLoaderHybrid() { val tableName = TableName(namespace = "test-namespace", name = "test-name") - val tableNames = TableNames(rawTableName = null, finalTableName = tableName) + val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp") + val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName) val stream = mockk { every { minimumGenerationId } returns 1L every { generationId } returns 2L + every { importType } returns Append + every { tableSchema } returns + StreamTableSchema( + tableNames = tableNames, + columnSchema = + ColumnSchema( + inputToFinalColumnNames = emptyMap(), + finalSchema = emptyMap(), + inputSchema = emptyMap() + ), + importType = Append + ) + every { mappedDescriptor } returns + DestinationStream.Descriptor( + namespace = tableName.namespace, + name = tableName.name + ) } - val tableInfo = - TableNameInfo( - tableNames = tableNames, - columnNameMapping = ColumnNameMapping(emptyMap()) - ) - val catalog = TableCatalog(mapOf(stream to tableInfo)) + val catalog = DestinationCatalog(listOf(stream)) val snowflakeClient = mockk(relaxed = true) val stateGatherer = mockk> { - coEvery { gatherInitialStatus(catalog) } returns + coEvery { gatherInitialStatus() } returns mapOf( stream to DirectLoadInitialStatus( @@ -186,14 +258,18 @@ internal class SnowflakeWriterTest { } val tempTableNameGenerator = mockk { every { generate(any()) } answers { firstArg() } } + val streamStateStore = mockk>() val writer = SnowflakeWriter( - names = catalog, + catalog = catalog, stateGatherer = stateGatherer, - streamStateStore = mockk(), + streamStateStore = streamStateStore, snowflakeClient = snowflakeClient, tempTableNameGenerator = tempTableNameGenerator, - snowflakeConfiguration = mockk(relaxed = true), + snowflakeConfiguration = + mockk(relaxed = true) { + every { internalTableSchema } returns "internal_schema" + }, ) runBlocking { @@ -203,169 +279,126 @@ internal class SnowflakeWriterTest { } @Test - fun testSetupWithNamespaceCreationFailure() { - val tableName = TableName(namespace = "test-namespace", name = "test-name") - val tableNames = 
TableNames(rawTableName = null, finalTableName = tableName) - val stream = mockk() - val tableInfo = - TableNameInfo( - tableNames = tableNames, - columnNameMapping = ColumnNameMapping(emptyMap()) - ) - val catalog = TableCatalog(mapOf(stream to tableInfo)) - val snowflakeClient = mockk() - val stateGatherer = mockk>() - val writer = - SnowflakeWriter( - names = catalog, - stateGatherer = stateGatherer, - streamStateStore = mockk(), - snowflakeClient = snowflakeClient, - tempTableNameGenerator = mockk(), - snowflakeConfiguration = mockk(), - ) - - // Simulate network failure during namespace creation - coEvery { - snowflakeClient.createNamespace(tableName.namespace.toSnowflakeCompatibleName()) - } throws RuntimeException("Network connection failed") - - assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } } - } - - @Test - fun testSetupWithInitialStatusGatheringFailure() { - val tableName = TableName(namespace = "test-namespace", name = "test-name") - val tableNames = TableNames(rawTableName = null, finalTableName = tableName) - val stream = mockk() - val tableInfo = - TableNameInfo( - tableNames = tableNames, - columnNameMapping = ColumnNameMapping(emptyMap()) - ) - val catalog = TableCatalog(mapOf(stream to tableInfo)) - val snowflakeClient = mockk(relaxed = true) - val stateGatherer = mockk>() - val writer = - SnowflakeWriter( - names = catalog, - stateGatherer = stateGatherer, - streamStateStore = mockk(), - snowflakeClient = snowflakeClient, - tempTableNameGenerator = mockk(), - snowflakeConfiguration = mockk(), - ) - - // Simulate failure while gathering initial status - coEvery { stateGatherer.gatherInitialStatus(catalog) } throws - RuntimeException("Failed to query table status") - - assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } } - - // Verify namespace creation was still attempted - coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) } - } - - @Test - fun testCreateStreamLoaderWithMissingInitialStatus() { - val tableName = TableName(namespace = "test-namespace", name = "test-name") - val tableNames = TableNames(rawTableName = null, finalTableName = tableName) + fun testCreateStreamLoaderNamespaceLegacy() { + val namespace = "test-namespace" + val name = "test-name" + val tableName = TableName(namespace = namespace, name = name) + val tempTableName = TableName(namespace = namespace, name = "${name}-temp") + val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName) val stream = mockk { every { minimumGenerationId } returns 0L every { generationId } returns 0L + every { importType } returns Append + every { tableSchema } returns + StreamTableSchema( + tableNames = tableNames, + columnSchema = + ColumnSchema( + inputToFinalColumnNames = emptyMap(), + finalSchema = emptyMap(), + inputSchema = emptyMap() + ), + importType = Append + ) + every { mappedDescriptor } returns + DestinationStream.Descriptor( + namespace = tableName.namespace, + name = tableName.name + ) } - val missingStream = - mockk { - every { minimumGenerationId } returns 0L - every { generationId } returns 0L - } - val tableInfo = - TableNameInfo( - tableNames = tableNames, - columnNameMapping = ColumnNameMapping(emptyMap()) - ) - val catalog = TableCatalog(mapOf(stream to tableInfo)) + val catalog = DestinationCatalog(listOf(stream)) val snowflakeClient = mockk(relaxed = true) val stateGatherer = mockk> { - coEvery { gatherInitialStatus(catalog) } returns + coEvery { gatherInitialStatus() } returns mapOf( stream 
to DirectLoadInitialStatus( realTable = DirectLoadTableStatus(false), - tempTable = null, + tempTable = null ) ) } + val streamStateStore = mockk>() val writer = SnowflakeWriter( - names = catalog, + catalog = catalog, stateGatherer = stateGatherer, - streamStateStore = mockk(), + streamStateStore = streamStateStore, snowflakeClient = snowflakeClient, tempTableNameGenerator = mockk(), - snowflakeConfiguration = mockk(relaxed = true), + snowflakeConfiguration = + mockk(relaxed = true) { + every { legacyRawTablesOnly } returns true + every { internalTableSchema } returns "internal_schema" + }, ) - runBlocking { - writer.setup() - // Try to create loader for a stream that wasn't in initial status - assertThrows(NullPointerException::class.java) { - writer.createStreamLoader(missingStream) + runBlocking { writer.setup() } + + coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) } + } + + @Test + fun testCreateStreamLoaderNamespaceNonLegacy() { + val namespace = "test-namespace" + val name = "test-name" + val tableName = TableName(namespace = namespace, name = name) + val tempTableName = TableName(namespace = namespace, name = "${name}-temp") + val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName) + val stream = + mockk { + every { minimumGenerationId } returns 0L + every { generationId } returns 0L + every { importType } returns Append + every { tableSchema } returns + StreamTableSchema( + tableNames = tableNames, + columnSchema = + ColumnSchema( + inputToFinalColumnNames = emptyMap(), + finalSchema = emptyMap(), + inputSchema = emptyMap() + ), + importType = Append + ) + every { mappedDescriptor } returns + DestinationStream.Descriptor( + namespace = tableName.namespace, + name = tableName.name + ) } - } - } - - @Test - fun testCreateStreamLoaderWithNullFinalTableName() { - // TableNames constructor throws IllegalStateException when both names are null - assertThrows(IllegalStateException::class.java) { - TableNames(rawTableName = null, finalTableName = null) - } - } - - @Test - fun testSetupWithMultipleNamespaceFailuresPartial() { - val tableName1 = TableName(namespace = "namespace1", name = "table1") - val tableName2 = TableName(namespace = "namespace2", name = "table2") - val tableNames1 = TableNames(rawTableName = null, finalTableName = tableName1) - val tableNames2 = TableNames(rawTableName = null, finalTableName = tableName2) - val stream1 = mockk() - val stream2 = mockk() - val tableInfo1 = - TableNameInfo( - tableNames = tableNames1, - columnNameMapping = ColumnNameMapping(emptyMap()) - ) - val tableInfo2 = - TableNameInfo( - tableNames = tableNames2, - columnNameMapping = ColumnNameMapping(emptyMap()) - ) - val catalog = TableCatalog(mapOf(stream1 to tableInfo1, stream2 to tableInfo2)) - val snowflakeClient = mockk() - val stateGatherer = mockk>() + val catalog = DestinationCatalog(listOf(stream)) + val snowflakeClient = mockk(relaxed = true) + val stateGatherer = + mockk> { + coEvery { gatherInitialStatus() } returns + mapOf( + stream to + DirectLoadInitialStatus( + realTable = DirectLoadTableStatus(false), + tempTable = null + ) + ) + } + val streamStateStore = mockk>() val writer = SnowflakeWriter( - names = catalog, + catalog = catalog, stateGatherer = stateGatherer, - streamStateStore = mockk(), + streamStateStore = streamStateStore, snowflakeClient = snowflakeClient, tempTableNameGenerator = mockk(), - snowflakeConfiguration = mockk(), + snowflakeConfiguration = + mockk(relaxed = true) { + every { legacyRawTablesOnly } 
returns false + every { internalTableSchema } returns "internal_schema" + }, ) - // First namespace succeeds, second fails (namespaces are uppercased by - // toSnowflakeCompatibleName) - coEvery { snowflakeClient.createNamespace("namespace1") } returns Unit - coEvery { snowflakeClient.createNamespace("namespace2") } throws - RuntimeException("Connection timeout") + runBlocking { writer.setup() } - assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } } - - // Verify both namespace creations were attempted - coVerify(exactly = 1) { snowflakeClient.createNamespace("namespace1") } - coVerify(exactly = 1) { snowflakeClient.createNamespace("namespace2") } + coVerify(exactly = 1) { snowflakeClient.createNamespace(namespace) } } } diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeInsertBufferTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeInsertBufferTest.kt index 47eda9daac4..63bd5fc6542 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeInsertBufferTest.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeInsertBufferTest.kt @@ -4,218 +4,220 @@ package io.airbyte.integrations.destination.snowflake.write.load +import io.airbyte.cdk.load.component.ColumnType import io.airbyte.cdk.load.data.AirbyteValue +import io.airbyte.cdk.load.data.FieldType import io.airbyte.cdk.load.data.IntegerValue import io.airbyte.cdk.load.data.NullValue +import io.airbyte.cdk.load.data.StringType import io.airbyte.cdk.load.data.StringValue import io.airbyte.cdk.load.message.Meta -import io.airbyte.cdk.load.table.TableName +import io.airbyte.cdk.load.schema.model.ColumnSchema +import io.airbyte.cdk.load.schema.model.TableName import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient +import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration -import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils import io.mockk.coVerify import io.mockk.every import io.mockk.mockk import java.io.BufferedReader -import java.io.File import java.io.InputStreamReader import java.util.zip.GZIPInputStream import kotlin.io.path.exists import kotlinx.coroutines.runBlocking import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNotNull import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test internal class SnowflakeInsertBufferTest { private lateinit var snowflakeConfiguration: SnowflakeConfiguration - private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils + private lateinit var columnManager: SnowflakeColumnManager + private lateinit var columnSchema: ColumnSchema + private lateinit var snowflakeRecordFormatter: SnowflakeRecordFormatter @BeforeEach fun setUp() { snowflakeConfiguration = mockk(relaxed = true) - snowflakeColumnUtils = mockk(relaxed = true) + snowflakeRecordFormatter = SnowflakeSchemaRecordFormatter() + columnManager = + mockk(relaxed = true) { + every { getMetaColumns() } returns + linkedMapOf( + "_AIRBYTE_RAW_ID" to ColumnType("VARCHAR", false), + "_AIRBYTE_EXTRACTED_AT" to ColumnType("TIMESTAMP_TZ", false), 
+ "_AIRBYTE_META" to ColumnType("VARIANT", false), + "_AIRBYTE_GENERATION_ID" to ColumnType("NUMBER", true) + ) + every { getTableColumnNames(any()) } returns + listOf( + "_AIRBYTE_RAW_ID", + "_AIRBYTE_EXTRACTED_AT", + "_AIRBYTE_META", + "_AIRBYTE_GENERATION_ID", + "columnName" + ) + } } @Test fun testAccumulate() { - val tableName = mockk(relaxed = true) + val tableName = TableName(namespace = "test", name = "table") val column = "columnName" - val columns = linkedMapOf(column to "NUMBER(38,0)") + columnSchema = + ColumnSchema( + inputToFinalColumnNames = mapOf(column to column.uppercase()), + finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)), + inputSchema = mapOf(column to FieldType(StringType, nullable = true)) + ) val snowflakeAirbyteClient = mockk(relaxed = true) val record = createRecord(column) val buffer = SnowflakeInsertBuffer( tableName = tableName, - columns = columns, snowflakeClient = snowflakeAirbyteClient, snowflakeConfiguration = snowflakeConfiguration, - flushLimit = 1, - snowflakeColumnUtils = snowflakeColumnUtils, + columnSchema = columnSchema, + columnManager = columnManager, + snowflakeRecordFormatter = snowflakeRecordFormatter, ) - - buffer.accumulate(record) - - assertEquals(true, buffer.csvFilePath?.exists()) + assertEquals(0, buffer.recordCount) + runBlocking { buffer.accumulate(record) } assertEquals(1, buffer.recordCount) } @Test - fun testAccumulateRaw() { - val tableName = mockk(relaxed = true) + fun testFlushToStaging() { + val tableName = TableName(namespace = "test", name = "table") val column = "columnName" - val columns = linkedMapOf(column to "NUMBER(38,0)") + columnSchema = + ColumnSchema( + inputToFinalColumnNames = mapOf(column to column.uppercase()), + finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)), + inputSchema = mapOf(column to FieldType(StringType, nullable = true)) + ) val snowflakeAirbyteClient = mockk(relaxed = true) val record = createRecord(column) val buffer = SnowflakeInsertBuffer( tableName = tableName, - columns = columns, snowflakeClient = snowflakeAirbyteClient, snowflakeConfiguration = snowflakeConfiguration, + columnSchema = columnSchema, + columnManager = columnManager, + snowflakeRecordFormatter = snowflakeRecordFormatter, flushLimit = 1, - snowflakeColumnUtils = snowflakeColumnUtils, ) - - every { snowflakeConfiguration.legacyRawTablesOnly } returns true - - buffer.accumulate(record) - - assertEquals(true, buffer.csvFilePath?.exists()) - assertEquals(1, buffer.recordCount) - } - - @Test - fun testFlush() { - val tableName = mockk(relaxed = true) - val column = "columnName" - val columns = linkedMapOf(column to "NUMBER(38,0)") - val snowflakeAirbyteClient = mockk(relaxed = true) - val record = createRecord(column) - val buffer = - SnowflakeInsertBuffer( - tableName = tableName, - columns = columns, - snowflakeClient = snowflakeAirbyteClient, - snowflakeConfiguration = snowflakeConfiguration, - flushLimit = 1, - snowflakeColumnUtils = snowflakeColumnUtils, + val expectedColumnNames = + listOf( + "_AIRBYTE_RAW_ID", + "_AIRBYTE_EXTRACTED_AT", + "_AIRBYTE_META", + "_AIRBYTE_GENERATION_ID", + "columnName" ) - runBlocking { buffer.accumulate(record) buffer.flush() - } - - coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) } - coVerify(exactly = 1) { - snowflakeAirbyteClient.copyFromStage( - tableName, - match { it.endsWith("$CSV_FILE_EXTENSION$FILE_SUFFIX") } - ) + coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) } + coVerify(exactly = 1) { + 
snowflakeAirbyteClient.copyFromStage(tableName, any(), expectedColumnNames) + } } } @Test - fun testFlushRaw() { - val tableName = mockk(relaxed = true) + fun testFlushToNoStaging() { + val tableName = TableName(namespace = "test", name = "table") val column = "columnName" - val columns = linkedMapOf(column to "NUMBER(38,0)") + columnSchema = + ColumnSchema( + inputToFinalColumnNames = mapOf(column to column.uppercase()), + finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)), + inputSchema = mapOf(column to FieldType(StringType, nullable = true)) + ) + val snowflakeAirbyteClient = mockk(relaxed = true) + every { snowflakeConfiguration.legacyRawTablesOnly } returns true + val record = createRecord(column) + val buffer = + SnowflakeInsertBuffer( + tableName = tableName, + snowflakeClient = snowflakeAirbyteClient, + snowflakeConfiguration = snowflakeConfiguration, + columnSchema = columnSchema, + columnManager = columnManager, + snowflakeRecordFormatter = snowflakeRecordFormatter, + flushLimit = 1, + ) + val expectedColumnNames = + listOf( + "_AIRBYTE_RAW_ID", + "_AIRBYTE_EXTRACTED_AT", + "_AIRBYTE_META", + "_AIRBYTE_GENERATION_ID", + "columnName" + ) + runBlocking { + buffer.accumulate(record) + buffer.flush() + // In legacy raw mode, it still uses staging + coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) } + coVerify(exactly = 1) { + snowflakeAirbyteClient.copyFromStage(tableName, any(), expectedColumnNames) + } + } + } + + @Test + fun testFileCreation() { + val tableName = TableName(namespace = "test", name = "table") + val column = "columnName" + columnSchema = + ColumnSchema( + inputToFinalColumnNames = mapOf(column to column.uppercase()), + finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)), + inputSchema = mapOf(column to FieldType(StringType, nullable = true)) + ) val snowflakeAirbyteClient = mockk(relaxed = true) val record = createRecord(column) val buffer = SnowflakeInsertBuffer( tableName = tableName, - columns = columns, snowflakeClient = snowflakeAirbyteClient, snowflakeConfiguration = snowflakeConfiguration, + columnSchema = columnSchema, + columnManager = columnManager, + snowflakeRecordFormatter = snowflakeRecordFormatter, flushLimit = 1, - snowflakeColumnUtils = snowflakeColumnUtils, ) - - every { snowflakeConfiguration.legacyRawTablesOnly } returns true - runBlocking { buffer.accumulate(record) - buffer.flush() - } - - coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) } - coVerify(exactly = 1) { - snowflakeAirbyteClient.copyFromStage( - tableName, - match { it.endsWith("$CSV_FILE_EXTENSION$FILE_SUFFIX") } - ) - } - } - - @Test - fun testMissingFields() { - val tableName = mockk(relaxed = true) - val snowflakeAirbyteClient = mockk(relaxed = true) - val record = createRecord("COLUMN1") - val buffer = - SnowflakeInsertBuffer( - tableName = tableName, - columns = linkedMapOf("COLUMN1" to "NUMBER(38,0)", "COLUMN2" to "NUMBER(38,0)"), - snowflakeClient = snowflakeAirbyteClient, - snowflakeConfiguration = snowflakeConfiguration, - flushLimit = 1, - snowflakeColumnUtils = snowflakeColumnUtils, - ) - - runBlocking { - buffer.accumulate(record) - buffer.csvWriter?.flush() + // The csvFilePath is internal, we can access it for testing + val filepath = buffer.csvFilePath + assertNotNull(filepath) + val file = filepath!!.toFile() + assert(file.exists()) + // Close the writer to ensure all data is flushed buffer.csvWriter?.close() - assertEquals( - 
"test-value$CSV_FIELD_SEPARATOR$CSV_LINE_DELIMITER", - readFromCsvFile(buffer.csvFilePath!!.toFile()) - ) + val lines = mutableListOf() + GZIPInputStream(file.inputStream()).use { gzip -> + BufferedReader(InputStreamReader(gzip)).use { bufferedReader -> + bufferedReader.forEachLine { line -> lines.add(line) } + } + } + assertEquals(1, lines.size) + file.delete() } } - @Test - fun testMissingFieldsRaw() { - val tableName = mockk(relaxed = true) - val snowflakeAirbyteClient = mockk(relaxed = true) - val record = createRecord("COLUMN1") - val buffer = - SnowflakeInsertBuffer( - tableName = tableName, - columns = linkedMapOf("COLUMN1" to "NUMBER(38,0)", "COLUMN2" to "NUMBER(38,0)"), - snowflakeClient = snowflakeAirbyteClient, - snowflakeConfiguration = snowflakeConfiguration, - flushLimit = 1, - snowflakeColumnUtils = snowflakeColumnUtils, - ) - - every { snowflakeConfiguration.legacyRawTablesOnly } returns true - - runBlocking { - buffer.accumulate(record) - buffer.csvWriter?.flush() - buffer.csvWriter?.close() - assertEquals( - "test-value$CSV_FIELD_SEPARATOR$CSV_LINE_DELIMITER", - readFromCsvFile(buffer.csvFilePath!!.toFile()) - ) - } - } - - private fun readFromCsvFile(file: File) = - GZIPInputStream(file.inputStream()).use { input -> - val reader = BufferedReader(InputStreamReader(input)) - reader.readText() - } - - private fun createRecord(columnName: String) = - mapOf( - columnName to AirbyteValue.from("test-value"), - Meta.COLUMN_NAME_AB_EXTRACTED_AT to IntegerValue(System.currentTimeMillis()), - Meta.COLUMN_NAME_AB_RAW_ID to StringValue("raw-id"), - Meta.COLUMN_NAME_AB_GENERATION_ID to IntegerValue(1223), - Meta.COLUMN_NAME_AB_META to StringValue("{\"changes\":[],\"syncId\":43}"), - "${columnName}Null" to NullValue + private fun createRecord(column: String): Map { + return mapOf( + column to IntegerValue(value = 42), + Meta.COLUMN_NAME_AB_GENERATION_ID to NullValue, + Meta.COLUMN_NAME_AB_RAW_ID to StringValue("raw-id-1"), + Meta.COLUMN_NAME_AB_EXTRACTED_AT to IntegerValue(1234567890), + Meta.COLUMN_NAME_AB_META to StringValue("meta-data-foo"), ) + } } diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeRawRecordFormatterTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeRawRecordFormatterTest.kt index 7633c7d67fb..d4c54904db7 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeRawRecordFormatterTest.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeRawRecordFormatterTest.kt @@ -15,12 +15,8 @@ import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA -import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils -import io.mockk.every -import io.mockk.mockk -import kotlin.collections.plus +import io.airbyte.cdk.load.schema.model.ColumnSchema import org.junit.jupiter.api.Assertions.assertEquals -import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test private val AIRBYTE_COLUMN_TYPES_MAP = @@ -58,28 +54,16 @@ private fun createExpected( internal class 
SnowflakeRawRecordFormatterTest { - private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils - - @BeforeEach - fun setup() { - snowflakeColumnUtils = mockk { - every { getFormattedDefaultColumnNames(any()) } returns - AIRBYTE_COLUMN_TYPES_MAP.keys.toList() - } - } - @Test fun testFormatting() { val columnName = "test-column-name" val columnValue = "test-column-value" val columns = AIRBYTE_COLUMN_TYPES_MAP val record = createRecord(columnName = columnName, columnValue = columnValue) - val formatter = - SnowflakeRawRecordFormatter( - columns = AIRBYTE_COLUMN_TYPES_MAP, - snowflakeColumnUtils = snowflakeColumnUtils - ) - val formattedValue = formatter.format(record) + val formatter = SnowflakeRawRecordFormatter() + // RawRecordFormatter doesn't use columnSchema but still needs one per interface + val dummyColumnSchema = ColumnSchema(emptyMap(), emptyMap(), emptyMap()) + val formattedValue = formatter.format(record, dummyColumnSchema) val expectedValue = createExpected( record = record, @@ -93,33 +77,28 @@ internal class SnowflakeRawRecordFormatterTest { fun testFormattingMigratedFromPreviousVersion() { val columnName = "test-column-name" val columnValue = "test-column-value" - val columnsMap = - linkedMapOf( - COLUMN_NAME_AB_EXTRACTED_AT to "TIMESTAMP_TZ(9)", - COLUMN_NAME_AB_LOADED_AT to "TIMESTAMP_TZ(9)", - COLUMN_NAME_AB_META to "VARIANT", - COLUMN_NAME_DATA to "VARIANT", - COLUMN_NAME_AB_RAW_ID to "VARCHAR(16777216)", - COLUMN_NAME_AB_GENERATION_ID to "NUMBER(38,0)", - ) val record = createRecord(columnName = columnName, columnValue = columnValue) - val formatter = - SnowflakeRawRecordFormatter( - columns = columnsMap, - snowflakeColumnUtils = snowflakeColumnUtils - ) - val formattedValue = formatter.format(record) + val formatter = SnowflakeRawRecordFormatter() + // RawRecordFormatter doesn't use columnSchema but still needs one per interface + val dummyColumnSchema = ColumnSchema(emptyMap(), emptyMap(), emptyMap()) + val formattedValue = formatter.format(record, dummyColumnSchema) + + // The formatter outputs in a fixed order regardless of input column order: + // 1. AB_RAW_ID + // 2. AB_EXTRACTED_AT + // 3. AB_META + // 4. AB_GENERATION_ID + // 5. AB_LOADED_AT + // 6. 
diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeSchemaRecordFormatterTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeSchemaRecordFormatterTest.kt
index fd21eb0049a..a30d30c01f1 100644
--- a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeSchemaRecordFormatterTest.kt
+++ b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/write/load/SnowflakeSchemaRecordFormatterTest.kt
@@ -4,66 +4,58 @@
 package io.airbyte.integrations.destination.snowflake.write.load
 
+import io.airbyte.cdk.load.component.ColumnType
 import io.airbyte.cdk.load.data.AirbyteValue
+import io.airbyte.cdk.load.data.FieldType
 import io.airbyte.cdk.load.data.IntegerValue
 import io.airbyte.cdk.load.data.NullValue
+import io.airbyte.cdk.load.data.StringType
 import io.airbyte.cdk.load.data.StringValue
 import io.airbyte.cdk.load.data.csv.toCsvValue
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
-import io.mockk.every
-import io.mockk.mockk
-import java.util.AbstractMap
-import kotlin.collections.component1
-import kotlin.collections.component2
+import io.airbyte.cdk.load.schema.model.ColumnSchema
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
 import org.junit.jupiter.api.Assertions.assertEquals
-import org.junit.jupiter.api.BeforeEach
 import org.junit.jupiter.api.Test
 
-private val AIRBYTE_COLUMN_TYPES_MAP =
-    linkedMapOf(
-            COLUMN_NAME_AB_RAW_ID to "VARCHAR(16777216)",
-            COLUMN_NAME_AB_EXTRACTED_AT to "TIMESTAMP_TZ(9)",
-            COLUMN_NAME_AB_META to "VARIANT",
-            COLUMN_NAME_AB_GENERATION_ID to "NUMBER(38,0)",
-        )
-        .mapKeys { it.key.toSnowflakeCompatibleName() }
-
 internal class SnowflakeSchemaRecordFormatterTest {
-    private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
+    private fun createColumnSchema(userColumns: Map<String, String>): ColumnSchema {
+        val finalSchema = linkedMapOf<String, ColumnType>()
+        val inputToFinalColumnNames = mutableMapOf<String, String>()
+        val inputSchema = mutableMapOf<String, FieldType>()
 
-    @BeforeEach
-    fun setup() {
-        snowflakeColumnUtils = mockk {
-            every { getFormattedDefaultColumnNames(any()) } returns
-                AIRBYTE_COLUMN_TYPES_MAP.keys.toList()
+        // Add user columns
+        userColumns.forEach { (name, type) ->
+            val finalName = name.toSnowflakeCompatibleName()
+            finalSchema[finalName] = ColumnType(type, true)
+            inputToFinalColumnNames[name] = finalName
+            inputSchema[name] = FieldType(StringType, nullable = true)
         }
+
+        return ColumnSchema(
+            inputToFinalColumnNames = inputToFinalColumnNames,
+            finalSchema = finalSchema,
+            inputSchema = inputSchema
+        )
     }
 
     @Test
     fun testFormatting() {
         val columnName = "test-column-name"
         val columnValue = "test-column-value"
-        val columns =
-            (AIRBYTE_COLUMN_TYPES_MAP + linkedMapOf(columnName to "VARCHAR(16777216)")).mapKeys {
-                it.key.toSnowflakeCompatibleName()
-            }
+        val userColumns = mapOf(columnName to "VARCHAR(16777216)")
+        val columnSchema = createColumnSchema(userColumns)
         val record = createRecord(columnName, columnValue)
-        val formatter =
-            SnowflakeSchemaRecordFormatter(
-                columns = columns as LinkedHashMap<String, String>,
-                snowflakeColumnUtils = snowflakeColumnUtils
-            )
-        val formattedValue = formatter.format(record)
+        val formatter = SnowflakeSchemaRecordFormatter()
+        val formattedValue = formatter.format(record, columnSchema)
         val expectedValue =
             createExpected(
                 record = record,
-                columns = columns,
+                columnSchema = columnSchema,
             )
         assertEquals(expectedValue, formattedValue)
     }
@@ -72,21 +64,15 @@ internal class SnowflakeSchemaRecordFormatterTest {
     fun testFormattingVariant() {
         val columnName = "test-column-name"
         val columnValue = "{\"test\": \"test-value\"}"
-        val columns =
-            (AIRBYTE_COLUMN_TYPES_MAP + linkedMapOf(columnName to "VARIANT")).mapKeys {
-                it.key.toSnowflakeCompatibleName()
-            }
+        val userColumns = mapOf(columnName to "VARIANT")
+        val columnSchema = createColumnSchema(userColumns)
         val record = createRecord(columnName, columnValue)
-        val formatter =
-            SnowflakeSchemaRecordFormatter(
-                columns = columns as LinkedHashMap<String, String>,
-                snowflakeColumnUtils = snowflakeColumnUtils
-            )
-        val formattedValue = formatter.format(record)
+        val formatter = SnowflakeSchemaRecordFormatter()
+        val formattedValue = formatter.format(record, columnSchema)
         val expectedValue =
             createExpected(
                 record = record,
-                columns = columns,
+                columnSchema = columnSchema,
             )
         assertEquals(expectedValue, formattedValue)
     }
@@ -95,23 +81,16 @@ internal class SnowflakeSchemaRecordFormatterTest {
     fun testFormattingMissingColumn() {
         val columnName = "test-column-name"
         val columnValue = "test-column-value"
-        val columns =
-            AIRBYTE_COLUMN_TYPES_MAP +
-                linkedMapOf(
-                    columnName to "VARCHAR(16777216)",
-                    "missing-column" to "VARCHAR(16777216)"
-                )
+        val userColumns =
+            mapOf(columnName to "VARCHAR(16777216)", "missing-column" to "VARCHAR(16777216)")
+        val columnSchema = createColumnSchema(userColumns)
         val record = createRecord(columnName, columnValue)
-        val formatter =
-            SnowflakeSchemaRecordFormatter(
-                columns = columns as LinkedHashMap<String, String>,
-                snowflakeColumnUtils = snowflakeColumnUtils
-            )
-        val formattedValue = formatter.format(record)
+        val formatter = SnowflakeSchemaRecordFormatter()
+        val formattedValue = formatter.format(record, columnSchema)
         val expectedValue =
             createExpected(
                 record = record,
-                columns = columns,
+                columnSchema = columnSchema,
                 filterMissing = false,
             )
         assertEquals(expectedValue, formattedValue)
@@ -128,16 +107,37 @@ internal class SnowflakeSchemaRecordFormatterTest {
 
     private fun createExpected(
         record: Map<String, AirbyteValue>,
-        columns: Map<String, String>,
+        columnSchema: ColumnSchema,
         filterMissing: Boolean = true,
-    ) =
-        record.entries
-            .associate { entry -> entry.key.toSnowflakeCompatibleName() to entry.value }
-            .map { entry -> AbstractMap.SimpleEntry(entry.key, entry.value.toCsvValue()) }
-            .sortedBy { entry ->
-                if (columns.keys.indexOf(entry.key) > -1) columns.keys.indexOf(entry.key)
-                else Int.MAX_VALUE
+    ): List<String> {
+        val columns = columnSchema.finalSchema.keys.toList()
+        val result = mutableListOf<String>()
+
+        // Add meta columns first in the expected order
+        result.add(record[COLUMN_NAME_AB_RAW_ID]?.toCsvValue() ?: "")
+        result.add(record[COLUMN_NAME_AB_EXTRACTED_AT]?.toCsvValue() ?: "")
+        result.add(record[COLUMN_NAME_AB_META]?.toCsvValue() ?: "")
+        result.add(record[COLUMN_NAME_AB_GENERATION_ID]?.toCsvValue() ?: "")
+
+        // Add user columns
+        val userColumns =
+            columns.filterNot { col ->
+                listOf(
+                        COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName(),
+                        COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName(),
+                        COLUMN_NAME_AB_META.toSnowflakeCompatibleName(),
+                        COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()
+                    )
+                    .contains(col)
             }
-            .filter { (k, _) -> if (filterMissing) columns.contains(k) else true }
-            .map { it.value }
+
+        userColumns.forEach { columnName ->
+            val value = record[columnName] ?: if (!filterMissing) NullValue else null
+            if (value != null || !filterMissing) {
+                result.add(value?.toCsvValue() ?: "")
+            }
+        }
+
+        return result
+    }
 }
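Taken together, the helpers above reduce to a short usage pattern; a sketch under the same assumptions as the test (CDK types as imported there, arbitrary column name and value):

// Inside the test class: build a one-column schema, format a record through it,
// and compare against the expectation helper. Meta columns come first, then the
// user column's CSV-rendered value.
val columnSchema = createColumnSchema(mapOf("test-column-name" to "VARCHAR(16777216)"))
val record = createRecord("test-column-name", "test-column-value")
val formatted = SnowflakeSchemaRecordFormatter().format(record, columnSchema)
assertEquals(createExpected(record, columnSchema), formatted)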
diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/write/transform/SnowflakeColumnNameMapperTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/write/transform/SnowflakeColumnNameMapperTest.kt
deleted file mode 100644
index fc2f73d918e..00000000000
--- a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/write/transform/SnowflakeColumnNameMapperTest.kt
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.snowflake.write.transform
-
-import io.airbyte.cdk.load.command.DestinationStream
-import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
-import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.mockk.every
-import io.mockk.mockk
-import org.junit.jupiter.api.Assertions.assertEquals
-import org.junit.jupiter.api.Test
-
-internal class SnowflakeColumnNameMapperTest {
-
-    @Test
-    fun testGetMappedColumnName() {
-        val columnName = "tést-column-name"
-        val expectedName = "test-column-name"
-        val stream = mockk<DestinationStream>()
-        val tableCatalog = mockk<TableCatalog>()
-        val snowflakeConfiguration = mockk<SnowflakeConfiguration>(relaxed = true)
-
-        // Configure the mock to return the expected mapped column name
-        every { tableCatalog.getMappedColumnName(stream, columnName) } returns expectedName
-
-        val mapper = SnowflakeColumnNameMapper(tableCatalog, snowflakeConfiguration)
-        val result = mapper.getMappedColumnName(stream = stream, columnName = columnName)
-        assertEquals(expectedName, result)
-    }
-
-    @Test
-    fun testGetMappedColumnNameRawFormat() {
-        val columnName = "tést-column-name"
-        val stream = mockk<DestinationStream>()
-        val tableCatalog = mockk<TableCatalog>()
-        val snowflakeConfiguration =
-            mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns true }
-
-        val mapper = SnowflakeColumnNameMapper(tableCatalog, snowflakeConfiguration)
-        val result = mapper.getMappedColumnName(stream = stream, columnName = columnName)
-        assertEquals(columnName, result)
-    }
-}
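SnowflakeColumnNameMapper and its test are deleted rather than ported; the rename it performed appears to be carried by the inputToFinalColumnNames map that createColumnSchema populates above. A minimal sketch of the equivalent lookup (resolveFinalName is our name, not the CDK's; the fallback mirrors the deleted raw-format test, which expected the input name back unchanged):

// Hypothetical stand-in for the deleted mapper: resolve a source column to its
// Snowflake-side name via the schema's mapping, falling back to the input name
// when no mapping exists (as legacy raw-tables mode did in the removed test).
fun resolveFinalName(inputToFinalColumnNames: Map<String, String>, inputName: String): String =
    inputToFinalColumnNames[inputName] ?: inputName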
diff --git a/docs/integrations/destinations/snowflake.md b/docs/integrations/destinations/snowflake.md
index 55c1ed9475a..31faba4e711 100644
--- a/docs/integrations/destinations/snowflake.md
+++ b/docs/integrations/destinations/snowflake.md
@@ -260,6 +260,7 @@ desired namespace.
 
 | Version     | Date       | Pull Request                                             | Subject                                                                                      |
 |:------------|:-----------|:---------------------------------------------------------|:---------------------------------------------------------------------------------------------|
+| 4.0.32-rc.1 | 2025-12-18 | [70903](https://github.com/airbytehq/airbyte/pull/70903) | Upgrade to CDK 0.1.91; internal refactoring                                                  |
 | 4.0.31      | 2025-12-08 | [70442](https://github.com/airbytehq/airbyte/pull/70442) | Write VARIANT values correctly when underlying Airbyte type is a `union`                     |
 | 4.0.30      | 2025-11-24 | [69842](https://github.com/airbytehq/airbyte/pull/69842) | Update documentation about numeric value handling                                            |
 | 4.0.29      | 2025-11-14 | [69342](https://github.com/airbytehq/airbyte/pull/69342) | Truncate NumberValues and IntegerValues with excessive precision instead of nullifying them  |