
Snowflake maps its schema once at the start. (#70903)

Ryan Br...
2025-12-18 15:14:01 -08:00
committed by GitHub
parent 226af71657
commit 272f243e44
47 changed files with 2047 additions and 2557 deletions
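In outline: instead of re-deriving table and column names on every operation (SnowflakeColumnUtils, SnowflakeFinalTableNameGenerator) and re-fetching columns per aggregate, the schema is now mapped once per stream into a StreamTableSchema that rides on DestinationStream. A minimal sketch of the resulting shape, using hypothetical stream and field names; the model classes and import locations are taken from the diffs below, except io.airbyte.cdk.load.command.Append, which is inferred from the Dedupe import removed in the SQL generator:

import io.airbyte.cdk.load.command.Append
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.cdk.load.schema.model.StreamTableSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.schema.model.TableNames

// Mapped once, up front (hypothetical "users" stream with one field):
val tableSchema =
    StreamTableSchema(
        tableNames =
            TableNames(
                finalTableName = TableName(namespace = "PUBLIC", name = "USERS"),
                tempTableName = TableName(namespace = "PUBLIC", name = "USERS_TMP"),
            ),
        columnSchema =
            ColumnSchema(
                inputToFinalColumnNames = mapOf("user name" to "USER NAME"),
                finalSchema = mapOf("USER NAME" to ColumnType("VARCHAR", true)),
                inputSchema = mapOf("user name" to FieldType(StringType, nullable = true)),
            ),
        importType = Append,
    )

// ...and every later consumer (SQL generator, insert buffer, checker) just reads it:
val finalColumns = tableSchema.columnSchema.finalSchema.keys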

View File

@@ -1,3 +1,3 @@
 testExecutionConcurrency=-1
-cdkVersion=0.1.82
+cdkVersion=0.1.91
 JunitMethodExecutionTimeout=10m

View File

@@ -6,7 +6,7 @@ data:
   connectorSubtype: database
   connectorType: destination
   definitionId: 424892c4-daac-4491-b35d-c6688ba547ba
-  dockerImageTag: 4.0.31
+  dockerImageTag: 4.0.32-rc.1
   dockerRepository: airbyte/destination-snowflake
   documentationUrl: https://docs.airbyte.com/integrations/destinations/snowflake
   githubIssueLabel: destination-snowflake
@@ -31,6 +31,8 @@ data:
       enabled: true
   releaseStage: generally_available
   releases:
+    rolloutConfiguration:
+      enableProgressiveRollout: true
     breakingChanges:
       2.0.0:
         message: Remove GCS/S3 loading method support.

View File

@@ -11,15 +11,18 @@ import io.airbyte.cdk.load.check.CheckOperationV2
 import io.airbyte.cdk.load.check.DestinationCheckerV2
 import io.airbyte.cdk.load.config.DataChannelMedium
 import io.airbyte.cdk.load.dataflow.config.AggregatePublishingConfig
-import io.airbyte.cdk.load.orchestration.db.DefaultTempTableNameGenerator
-import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
+import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator
+import io.airbyte.cdk.load.table.TempTableNameGenerator
 import io.airbyte.cdk.output.OutputConsumer
 import io.airbyte.integrations.destination.snowflake.cdk.SnowflakeMigratingConfigurationSpecificationSupplier
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
 import io.airbyte.integrations.destination.snowflake.spec.KeyPairAuthConfiguration
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfigurationFactory
 import io.airbyte.integrations.destination.snowflake.spec.UsernamePasswordAuthConfiguration
+import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRawRecordFormatter
+import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
+import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter
 import io.micronaut.context.annotation.Factory
 import io.micronaut.context.annotation.Primary
 import io.micronaut.context.annotation.Requires
@@ -204,6 +207,17 @@ class SnowflakeBeanFactory {
         outputConsumer: OutputConsumer,
     ) = CheckOperationV2(destinationChecker, outputConsumer)
 
+    @Singleton
+    fun snowflakeRecordFormatter(
+        snowflakeConfiguration: SnowflakeConfiguration
+    ): SnowflakeRecordFormatter {
+        return if (snowflakeConfiguration.legacyRawTablesOnly) {
+            SnowflakeRawRecordFormatter()
+        } else {
+            SnowflakeSchemaRecordFormatter()
+        }
+    }
+
     @Singleton
     fun aggregatePublishingConfig(dataChannelMedium: DataChannelMedium): AggregatePublishingConfig {
         // NOT speed mode

View File

@@ -13,13 +13,17 @@ import io.airbyte.cdk.load.data.FieldType
 import io.airbyte.cdk.load.data.ObjectType
 import io.airbyte.cdk.load.data.StringType
 import io.airbyte.cdk.load.message.Meta
+import io.airbyte.cdk.load.schema.model.ColumnSchema
+import io.airbyte.cdk.load.schema.model.StreamTableSchema
+import io.airbyte.cdk.load.schema.model.TableName
+import io.airbyte.cdk.load.schema.model.TableNames
 import io.airbyte.cdk.load.table.ColumnNameMapping
-import io.airbyte.cdk.load.table.TableName
 import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
 import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
+import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter
 import jakarta.inject.Singleton
 import java.time.OffsetDateTime
 import java.util.UUID
@@ -31,7 +35,7 @@ internal const val CHECK_COLUMN_NAME = "test_key"
 class SnowflakeChecker(
     private val snowflakeAirbyteClient: SnowflakeAirbyteClient,
     private val snowflakeConfiguration: SnowflakeConfiguration,
-    private val snowflakeColumnUtils: SnowflakeColumnUtils,
+    private val columnManager: SnowflakeColumnManager,
 ) : DestinationCheckerV2 {
 
     override fun check() {
@@ -46,11 +50,40 @@
                 Meta.AirbyteMetaFields.GENERATION_ID.fieldName to AirbyteValue.from(0),
                 CHECK_COLUMN_NAME.toSnowflakeCompatibleName() to AirbyteValue.from("test-value")
             )
-        val outputSchema = snowflakeConfiguration.schema.toSnowflakeCompatibleName()
+        val outputSchema =
+            if (snowflakeConfiguration.legacyRawTablesOnly) {
+                snowflakeConfiguration.schema
+            } else {
+                snowflakeConfiguration.schema.toSnowflakeCompatibleName()
+            }
         val tableName =
             "_airbyte_connection_test_${
             UUID.randomUUID().toString().replace("-".toRegex(), "")}".toSnowflakeCompatibleName()
         val qualifiedTableName = TableName(namespace = outputSchema, name = tableName)
+        val tableSchema =
+            StreamTableSchema(
+                tableNames =
+                    TableNames(
+                        finalTableName = qualifiedTableName,
+                        tempTableName = qualifiedTableName
+                    ),
+                columnSchema =
+                    ColumnSchema(
+                        inputToFinalColumnNames =
+                            mapOf(
+                                CHECK_COLUMN_NAME to CHECK_COLUMN_NAME.toSnowflakeCompatibleName()
+                            ),
+                        finalSchema =
+                            mapOf(
+                                CHECK_COLUMN_NAME.toSnowflakeCompatibleName() to
+                                    io.airbyte.cdk.load.component.ColumnType("VARCHAR", false)
+                            ),
+                        inputSchema =
+                            mapOf(CHECK_COLUMN_NAME to FieldType(StringType, nullable = false))
+                    ),
+                importType = Append
+            )
         val destinationStream =
             DestinationStream(
                 unmappedNamespace = outputSchema,
@@ -63,7 +96,8 @@
                 generationId = 0L,
                 minimumGenerationId = 0L,
                 syncId = 0L,
-                namespaceMapper = NamespaceMapper()
+                namespaceMapper = NamespaceMapper(),
+                tableSchema = tableSchema
             )
         runBlocking {
             try {
@@ -75,14 +109,14 @@
                     replace = true,
                 )
-                val columns = snowflakeAirbyteClient.describeTable(qualifiedTableName)
                 val snowflakeInsertBuffer =
                     SnowflakeInsertBuffer(
                         tableName = qualifiedTableName,
-                        columns = columns,
                         snowflakeClient = snowflakeAirbyteClient,
                         snowflakeConfiguration = snowflakeConfiguration,
-                        snowflakeColumnUtils = snowflakeColumnUtils,
+                        columnSchema = tableSchema.columnSchema,
+                        columnManager = columnManager,
+                        snowflakeRecordFormatter = SnowflakeSchemaRecordFormatter(),
                     )
                 snowflakeInsertBuffer.accumulate(data)

View File

@@ -13,18 +13,16 @@ import io.airbyte.cdk.load.component.TableColumns
 import io.airbyte.cdk.load.component.TableOperationsClient
 import io.airbyte.cdk.load.component.TableSchema
 import io.airbyte.cdk.load.component.TableSchemaEvolutionClient
+import io.airbyte.cdk.load.schema.model.TableName
 import io.airbyte.cdk.load.table.ColumnNameMapping
-import io.airbyte.cdk.load.table.TableName
 import io.airbyte.cdk.load.util.deserializeToNode
-import io.airbyte.integrations.destination.snowflake.db.ColumnDefinition
-import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
 import io.airbyte.integrations.destination.snowflake.sql.COUNT_TOTAL_ALIAS
-import io.airbyte.integrations.destination.snowflake.sql.NOT_NULL
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
 import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
 import io.airbyte.integrations.destination.snowflake.sql.andLog
+import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
 import io.github.oshai.kotlinlogging.KotlinLogging
 import jakarta.inject.Singleton
 import java.sql.ResultSet
@@ -41,13 +39,10 @@ private val log = KotlinLogging.logger {}
 class SnowflakeAirbyteClient(
     private val dataSource: DataSource,
     private val sqlGenerator: SnowflakeDirectLoadSqlGenerator,
-    private val snowflakeColumnUtils: SnowflakeColumnUtils,
     private val snowflakeConfiguration: SnowflakeConfiguration,
+    private val columnManager: SnowflakeColumnManager,
 ) : TableOperationsClient, TableSchemaEvolutionClient {
 
-    private val airbyteColumnNames =
-        snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet()
-
     override suspend fun countTable(tableName: TableName): Long? =
         try {
             dataSource.connection.use { connection ->
@@ -126,7 +121,7 @@
         columnNameMapping: ColumnNameMapping,
         replace: Boolean
     ) {
-        execute(sqlGenerator.createTable(stream, tableName, columnNameMapping, replace))
+        execute(sqlGenerator.createTable(tableName, stream.tableSchema, replace))
         execute(sqlGenerator.createSnowflakeStage(tableName))
     }
@@ -163,7 +158,15 @@
         sourceTableName: TableName,
         targetTableName: TableName
     ) {
-        execute(sqlGenerator.copyTable(columnNameMapping, sourceTableName, targetTableName))
+        // Get all column names from the mapping (both meta columns and user columns)
+        val columnNames = buildSet {
+            // Add Airbyte meta columns (using uppercase constants)
+            addAll(columnManager.getMetaColumnNames())
+            // Add user columns from mapping
+            addAll(columnNameMapping.values)
+        }
+        execute(sqlGenerator.copyTable(columnNames, sourceTableName, targetTableName))
     }
 
     override suspend fun upsertTable(
@@ -172,9 +175,7 @@
         sourceTableName: TableName,
         targetTableName: TableName
     ) {
-        execute(
-            sqlGenerator.upsertTable(stream, columnNameMapping, sourceTableName, targetTableName)
-        )
+        execute(sqlGenerator.upsertTable(stream.tableSchema, sourceTableName, targetTableName))
     }
 
     override suspend fun dropTable(tableName: TableName) {
@@ -206,7 +207,7 @@
         stream: DestinationStream,
         columnNameMapping: ColumnNameMapping
     ): TableSchema {
-        return TableSchema(getColumnsFromStream(stream, columnNameMapping))
+        return TableSchema(stream.tableSchema.columnSchema.finalSchema)
     }
 
     override suspend fun applyChangeset(
@@ -253,7 +254,7 @@
                 val columnName = escapeJsonIdentifier(rs.getString("name"))
                 // Filter out airbyte columns
-                if (airbyteColumnNames.contains(columnName)) {
+                if (columnManager.getMetaColumnNames().contains(columnName)) {
                     continue
                 }
                 val dataType = rs.getString("type").takeWhile { char -> char != '(' }
@@ -271,49 +272,6 @@
         }
     }
 
-    internal fun getColumnsFromStream(
-        stream: DestinationStream,
-        columnNameMapping: ColumnNameMapping
-    ): Map<String, ColumnType> =
-        snowflakeColumnUtils
-            .columnsAndTypes(stream.schema.asColumns(), columnNameMapping)
-            .filter { column -> column.columnName !in airbyteColumnNames }
-            .associate { column ->
-                // columnsAndTypes returns types as either `FOO` or `FOO NOT NULL`.
-                // so check for that suffix.
-                val nullable = !column.columnType.endsWith(NOT_NULL)
-                val type =
-                    column.columnType
-                        .takeWhile { char ->
-                            // This is to remove any precision parts of the dialect type
-                            char != '('
-                        }
-                        .removeSuffix(NOT_NULL)
-                        .trim()
-                column.columnName to ColumnType(type, nullable)
-            }
-
-    internal fun generateSchemaChanges(
-        columnsInDb: Set<ColumnDefinition>,
-        columnsInStream: Set<ColumnDefinition>
-    ): Triple<Set<ColumnDefinition>, Set<ColumnDefinition>, Set<ColumnDefinition>> {
-        val addedColumns =
-            columnsInStream.filter { it.name !in columnsInDb.map { col -> col.name } }.toSet()
-        val deletedColumns =
-            columnsInDb.filter { it.name !in columnsInStream.map { col -> col.name } }.toSet()
-        val commonColumns =
-            columnsInStream.filter { it.name in columnsInDb.map { col -> col.name } }.toSet()
-        val modifiedColumns =
-            commonColumns
-                .filter {
-                    val dbType = columnsInDb.find { column -> it.name == column.name }?.type
-                    it.type != dbType
-                }
-                .toSet()
-        return Triple(addedColumns, deletedColumns, modifiedColumns)
-    }
-
     override suspend fun getGenerationId(tableName: TableName): Long =
         try {
             dataSource.connection.use { connection ->
@@ -326,7 +284,7 @@
                      * format. In order to make sure these strings will match any column names
                      * that we have formatted in-memory, re-apply the escaping.
                      */
-                    resultSet.getLong(snowflakeColumnUtils.getGenerationIdColumnName())
+                    resultSet.getLong(columnManager.getGenerationIdColumnName())
                 } else {
                     log.warn {
                         "No generation ID found for table ${tableName.toPrettyString()}, returning 0"
@@ -351,8 +309,8 @@
         execute(sqlGenerator.putInStage(tableName, tempFilePath))
     }
 
-    fun copyFromStage(tableName: TableName, filename: String) {
-        execute(sqlGenerator.copyFromStage(tableName, filename))
+    fun copyFromStage(tableName: TableName, filename: String, columnNames: List<String>) {
+        execute(sqlGenerator.copyFromStage(tableName, filename, columnNames))
     }
 
     fun describeTable(tableName: TableName): LinkedHashMap<String, String> =

View File

@@ -4,47 +4,41 @@
 package io.airbyte.integrations.destination.snowflake.dataflow
 
+import io.airbyte.cdk.load.command.DestinationCatalog
 import io.airbyte.cdk.load.dataflow.aggregate.Aggregate
 import io.airbyte.cdk.load.dataflow.aggregate.AggregateFactory
 import io.airbyte.cdk.load.dataflow.aggregate.StoreKey
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig
-import io.airbyte.cdk.load.table.TableName
+import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
 import io.airbyte.cdk.load.write.StreamStateStore
 import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
 import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
-import io.micronaut.cache.annotation.CacheConfig
-import io.micronaut.cache.annotation.Cacheable
+import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
 import jakarta.inject.Singleton
 
 @Singleton
-@CacheConfig("table-columns")
-// class has to be open to make the cache stuff work
-open class SnowflakeAggregateFactory(
+class SnowflakeAggregateFactory(
     private val snowflakeClient: SnowflakeAirbyteClient,
     private val streamStateStore: StreamStateStore<DirectLoadTableExecutionConfig>,
     private val snowflakeConfiguration: SnowflakeConfiguration,
-    private val snowflakeColumnUtils: SnowflakeColumnUtils,
+    private val catalog: DestinationCatalog,
+    private val columnManager: SnowflakeColumnManager,
+    private val snowflakeRecordFormatter: SnowflakeRecordFormatter,
 ) : AggregateFactory {
 
     override fun create(key: StoreKey): Aggregate {
+        val stream = catalog.getStream(key)
         val tableName = streamStateStore.get(key)!!.tableName
         val buffer =
             SnowflakeInsertBuffer(
                 tableName = tableName,
-                columns = getTableColumns(tableName),
                 snowflakeClient = snowflakeClient,
                 snowflakeConfiguration = snowflakeConfiguration,
-                snowflakeColumnUtils = snowflakeColumnUtils,
+                columnSchema = stream.tableSchema.columnSchema,
+                columnManager = columnManager,
+                snowflakeRecordFormatter = snowflakeRecordFormatter,
             )
         return SnowflakeAggregate(buffer = buffer)
     }
-
-    // We assume that a table isn't getting altered _during_ a sync.
-    // This allows us to only SHOW COLUMNS once per table per sync,
-    // rather than refetching it on every aggregate.
-    @Cacheable
-    // function has to be open to make caching work
-    internal open fun getTableColumns(tableName: TableName) =
-        snowflakeClient.describeTable(tableName)
 }

View File

@@ -1,13 +0,0 @@
-/*
- * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.snowflake.db
-
-/**
- * Jdbc destination column definition representation
- *
- * @param name
- * @param type
- */
-data class ColumnDefinition(val name: String, val type: String)

View File

@@ -4,17 +4,17 @@
 package io.airbyte.integrations.destination.snowflake.db
 
+import io.airbyte.cdk.load.command.DestinationCatalog
 import io.airbyte.cdk.load.component.TableOperationsClient
-import io.airbyte.cdk.load.orchestration.db.BaseDirectLoadInitialStatusGatherer
-import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
+import io.airbyte.cdk.load.table.BaseDirectLoadInitialStatusGatherer
 import jakarta.inject.Singleton
 
 @Singleton
 class SnowflakeDirectLoadDatabaseInitialStatusGatherer(
     tableOperationsClient: TableOperationsClient,
-    tempTableNameGenerator: TempTableNameGenerator,
+    catalog: DestinationCatalog,
 ) :
     BaseDirectLoadInitialStatusGatherer(
         tableOperationsClient,
-        tempTableNameGenerator,
+        catalog,
     )

View File

@@ -1,95 +0,0 @@
-/*
- * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.snowflake.db
-
-import io.airbyte.cdk.ConfigErrorException
-import io.airbyte.cdk.load.command.DestinationStream
-import io.airbyte.cdk.load.orchestration.db.ColumnNameGenerator
-import io.airbyte.cdk.load.orchestration.db.FinalTableNameGenerator
-import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TypingDedupingUtil
-import io.airbyte.cdk.load.table.TableName
-import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.airbyte.integrations.destination.snowflake.sql.QUOTE
-import jakarta.inject.Singleton
-
-@Singleton
-class SnowflakeFinalTableNameGenerator(private val config: SnowflakeConfiguration) :
-    FinalTableNameGenerator {
-    override fun getTableName(streamDescriptor: DestinationStream.Descriptor): TableName {
-        val namespace = streamDescriptor.namespace ?: config.schema
-        return if (!config.legacyRawTablesOnly) {
-            TableName(
-                namespace = namespace.toSnowflakeCompatibleName(),
-                name = streamDescriptor.name.toSnowflakeCompatibleName(),
-            )
-        } else {
-            TableName(
-                namespace = config.internalTableSchema,
-                name =
-                    TypingDedupingUtil.concatenateRawTableName(
-                        namespace = escapeJsonIdentifier(namespace),
-                        name = escapeJsonIdentifier(streamDescriptor.name),
-                    ),
-            )
-        }
-    }
-}
-
-@Singleton
-class SnowflakeColumnNameGenerator(private val config: SnowflakeConfiguration) :
-    ColumnNameGenerator {
-    override fun getColumnName(column: String): ColumnNameGenerator.ColumnName {
-        return if (!config.legacyRawTablesOnly) {
-            ColumnNameGenerator.ColumnName(
-                column.toSnowflakeCompatibleName(),
-                column.toSnowflakeCompatibleName(),
-            )
-        } else {
-            ColumnNameGenerator.ColumnName(
-                column,
-                column,
-            )
-        }
-    }
-}
-
-/**
- * Escapes double-quotes in a JSON identifier by doubling them. This shit is legacy -- I don't know
- * why this would be necessary but no harm in keeping it so I am keeping it.
- *
- * @return The escaped identifier.
- */
-fun escapeJsonIdentifier(identifier: String): String {
-    // Note that we don't need to escape backslashes here!
-    // The only special character in an identifier is the double-quote, which needs to be
-    // doubled.
-    return identifier.replace(QUOTE, "$QUOTE$QUOTE")
-}
-
-/**
- * Transforms a string to be compatible with Snowflake table and column names.
- *
- * @return The transformed string suitable for Snowflake identifiers.
- */
-fun String.toSnowflakeCompatibleName(): String {
-    var identifier = this
-
-    // Handle empty strings
-    if (identifier.isEmpty()) {
-        throw ConfigErrorException("Empty string is invalid identifier")
-    }
-
-    // Snowflake scripting language does something weird when the `${` bigram shows up in the
-    // script so replace these with something else.
-    // For completeness, if we trigger this, also replace closing curly braces with underscores.
-    if (identifier.contains("\${")) {
-        identifier = identifier.replace("$", "_").replace("{", "_").replace("}", "_")
-    }
-
-    // Escape double quotes
-    identifier = escapeJsonIdentifier(identifier)
-
-    return identifier.uppercase()
-}

View File

@@ -0,0 +1,116 @@
+/*
+ * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.integrations.destination.snowflake.schema
+
+import io.airbyte.cdk.load.component.ColumnType
+import io.airbyte.cdk.load.message.Meta
+import io.airbyte.cdk.load.schema.model.ColumnSchema
+import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
+import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_EXTRACTED_AT
+import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_GENERATION_ID
+import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_META
+import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_RAW_ID
+import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDataType
+import jakarta.inject.Singleton
+
+/**
+ * Manages column names and ordering for Snowflake tables based on whether legacy raw tables mode is
+ * enabled.
+ *
+ * TODO: We should add meta column munging and raw table support to the CDK, so this extra layer of
+ * management shouldn't be necessary.
+ */
+@Singleton
+class SnowflakeColumnManager(
+    private val config: SnowflakeConfiguration,
+) {
+    /**
+     * Get the list of column names for a table in the order they should appear in the CSV file and
+     * COPY INTO statement.
+     *
+     * Warning: MUST match the order defined in SnowflakeRecordFormatter
+     *
+     * @param columnSchema The schema containing column information (ignored in raw mode)
+     * @return List of column names in the correct order
+     */
+    fun getTableColumnNames(columnSchema: ColumnSchema): List<String> {
+        return buildList {
+            addAll(getMetaColumnNames())
+            addAll(columnSchema.finalSchema.keys)
+        }
+    }
+
+    /**
+     * Get the list of Airbyte meta column names. In schema mode, these are uppercase. In raw mode,
+     * they are lowercase and include loaded_at.
+     *
+     * @return Set of meta column names
+     */
+    fun getMetaColumnNames(): Set<String> =
+        if (config.legacyRawTablesOnly) {
+            Constants.rawModeMetaColNames
+        } else {
+            Constants.schemaModeMetaColNames
+        }
+
+    /**
+     * Get the Airbyte meta columns as a map of column name to ColumnType. This provides both the
+     * column names and their types for table creation.
+     *
+     * @return Map of meta column names to their types
+     */
+    fun getMetaColumns(): LinkedHashMap<String, ColumnType> {
+        return if (config.legacyRawTablesOnly) {
+            Constants.rawModeMetaColumns
+        } else {
+            Constants.schemaModeMetaColumns
+        }
+    }
+
+    fun getGenerationIdColumnName(): String {
+        return if (config.legacyRawTablesOnly) {
+            Meta.COLUMN_NAME_AB_GENERATION_ID
+        } else {
+            SNOWFLAKE_AB_GENERATION_ID
+        }
+    }
+
+    object Constants {
+        val rawModeMetaColumns =
+            linkedMapOf(
+                Meta.COLUMN_NAME_AB_RAW_ID to ColumnType(SnowflakeDataType.VARCHAR.typeName, false),
+                Meta.COLUMN_NAME_AB_EXTRACTED_AT to
+                    ColumnType(
+                        SnowflakeDataType.TIMESTAMP_TZ.typeName,
+                        false,
+                    ),
+                Meta.COLUMN_NAME_AB_META to ColumnType(SnowflakeDataType.VARIANT.typeName, false),
+                Meta.COLUMN_NAME_AB_GENERATION_ID to
+                    ColumnType(
+                        SnowflakeDataType.NUMBER.typeName,
+                        true,
+                    ),
+                Meta.COLUMN_NAME_AB_LOADED_AT to
+                    ColumnType(
+                        SnowflakeDataType.TIMESTAMP_TZ.typeName,
+                        true,
+                    ),
+            )
+
+        val schemaModeMetaColumns =
+            linkedMapOf(
+                SNOWFLAKE_AB_RAW_ID to ColumnType(SnowflakeDataType.VARCHAR.typeName, false),
+                SNOWFLAKE_AB_EXTRACTED_AT to
+                    ColumnType(SnowflakeDataType.TIMESTAMP_TZ.typeName, false),
+                SNOWFLAKE_AB_META to ColumnType(SnowflakeDataType.VARIANT.typeName, false),
+                SNOWFLAKE_AB_GENERATION_ID to ColumnType(SnowflakeDataType.NUMBER.typeName, true),
+            )
+
+        val rawModeMetaColNames: Set<String> = rawModeMetaColumns.keys
+        val schemaModeMetaColNames: Set<String> = schemaModeMetaColumns.keys
+    }
+}
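To make the ordering contract above concrete, a hypothetical usage sketch (assuming an injected SnowflakeColumnManager in schema mode, and the uppercase meta column names produced by toSnowflakeCompatibleName):

import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerType
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager

fun printColumnOrder(columnManager: SnowflakeColumnManager) {
    val userSchema =
        ColumnSchema(
            inputToFinalColumnNames = mapOf("id" to "ID"),
            finalSchema = mapOf("ID" to ColumnType("NUMBER", true)),
            inputSchema = mapOf("id" to FieldType(IntegerType, nullable = true)),
        )
    // Meta columns always come first, then the user columns, so in schema mode this prints:
    // [_AIRBYTE_RAW_ID, _AIRBYTE_EXTRACTED_AT, _AIRBYTE_META, _AIRBYTE_GENERATION_ID, ID]
    println(columnManager.getTableColumnNames(userSchema))
}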

View File

@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.integrations.destination.snowflake.schema
+
+import io.airbyte.cdk.ConfigErrorException
+import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
+
+/**
+ * Transforms a string to be compatible with Snowflake table and column names.
+ *
+ * @return The transformed string suitable for Snowflake identifiers.
+ */
+fun String.toSnowflakeCompatibleName(): String {
+    var identifier = this
+
+    // Handle empty strings
+    if (identifier.isEmpty()) {
+        throw ConfigErrorException("Empty string is invalid identifier")
+    }
+
+    // Snowflake scripting language does something weird when the `${` bigram shows up in the
+    // script so replace these with something else.
+    // For completeness, if we trigger this, also replace closing curly braces with underscores.
+    if (identifier.contains("\${")) {
+        identifier = identifier.replace("$", "_").replace("{", "_").replace("}", "_")
+    }
+
+    // Escape double quotes
+    identifier = escapeJsonIdentifier(identifier)
+
+    return identifier.uppercase()
+}
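A few illustrative inputs and outputs, following the rules in the function body (uppercasing, quote doubling via escapeJsonIdentifier, and neutralizing the `${` bigram):

import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName

fun main() {
    println("users".toSnowflakeCompatibleName())      // USERS
    println("user name".toSnowflakeCompatibleName())  // USER NAME (spaces are kept; only case changes)
    println("say \"hi\"".toSnowflakeCompatibleName()) // SAY ""HI"" (double quotes are doubled)
    println("pre\${fix}".toSnowflakeCompatibleName()) // PRE__FIX_ ($, {, and } all become _)
    // "".toSnowflakeCompatibleName() throws ConfigErrorException
}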

View File

@@ -0,0 +1,121 @@
+/*
+ * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.integrations.destination.snowflake.schema
+
+import io.airbyte.cdk.load.command.DestinationStream
+import io.airbyte.cdk.load.component.ColumnType
+import io.airbyte.cdk.load.data.ArrayType
+import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
+import io.airbyte.cdk.load.data.BooleanType
+import io.airbyte.cdk.load.data.DateType
+import io.airbyte.cdk.load.data.FieldType
+import io.airbyte.cdk.load.data.IntegerType
+import io.airbyte.cdk.load.data.NumberType
+import io.airbyte.cdk.load.data.ObjectType
+import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
+import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
+import io.airbyte.cdk.load.data.StringType
+import io.airbyte.cdk.load.data.TimeTypeWithTimezone
+import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
+import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
+import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
+import io.airbyte.cdk.load.data.UnionType
+import io.airbyte.cdk.load.data.UnknownType
+import io.airbyte.cdk.load.message.Meta
+import io.airbyte.cdk.load.schema.TableSchemaMapper
+import io.airbyte.cdk.load.schema.model.StreamTableSchema
+import io.airbyte.cdk.load.schema.model.TableName
+import io.airbyte.cdk.load.table.TempTableNameGenerator
+import io.airbyte.cdk.load.table.TypingDedupingUtil
+import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
+import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDataType
+import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
+import jakarta.inject.Singleton
+
+@Singleton
+class SnowflakeTableSchemaMapper(
+    private val config: SnowflakeConfiguration,
+    private val tempTableNameGenerator: TempTableNameGenerator,
+) : TableSchemaMapper {
+    override fun toFinalTableName(desc: DestinationStream.Descriptor): TableName {
+        val namespace = desc.namespace ?: config.schema
+        return if (!config.legacyRawTablesOnly) {
+            TableName(
+                namespace = namespace.toSnowflakeCompatibleName(),
+                name = desc.name.toSnowflakeCompatibleName(),
+            )
+        } else {
+            TableName(
+                namespace = config.internalTableSchema,
+                name =
+                    TypingDedupingUtil.concatenateRawTableName(
+                        namespace = escapeJsonIdentifier(namespace),
+                        name = escapeJsonIdentifier(desc.name),
+                    ),
+            )
+        }
+    }
+
+    override fun toTempTableName(tableName: TableName): TableName {
+        return tempTableNameGenerator.generate(tableName)
+    }
+
+    override fun toColumnName(name: String): String {
+        return if (!config.legacyRawTablesOnly) {
+            name.toSnowflakeCompatibleName()
+        } else {
+            // In legacy mode, column names are not transformed
+            name
+        }
+    }
+
+    override fun toColumnType(fieldType: FieldType): ColumnType {
+        val snowflakeType =
+            when (fieldType.type) {
+                // Simple types
+                BooleanType -> SnowflakeDataType.BOOLEAN.typeName
+                IntegerType -> SnowflakeDataType.NUMBER.typeName
+                NumberType -> SnowflakeDataType.FLOAT.typeName
+                StringType -> SnowflakeDataType.VARCHAR.typeName
+
+                // Temporal types
+                DateType -> SnowflakeDataType.DATE.typeName
+                TimeTypeWithTimezone -> SnowflakeDataType.VARCHAR.typeName
+                TimeTypeWithoutTimezone -> SnowflakeDataType.TIME.typeName
+                TimestampTypeWithTimezone -> SnowflakeDataType.TIMESTAMP_TZ.typeName
+                TimestampTypeWithoutTimezone -> SnowflakeDataType.TIMESTAMP_NTZ.typeName
+
+                // Semistructured types
+                is ArrayType,
+                ArrayTypeWithoutSchema -> SnowflakeDataType.ARRAY.typeName
+                is ObjectType,
+                ObjectTypeWithEmptySchema,
+                ObjectTypeWithoutSchema -> SnowflakeDataType.OBJECT.typeName
+                is UnionType -> SnowflakeDataType.VARIANT.typeName
+                is UnknownType -> SnowflakeDataType.VARIANT.typeName
+            }
+        return ColumnType(snowflakeType, fieldType.nullable)
+    }
+
+    override fun toFinalSchema(tableSchema: StreamTableSchema): StreamTableSchema {
+        if (!config.legacyRawTablesOnly) {
+            return tableSchema
+        }
+        return StreamTableSchema(
+            tableNames = tableSchema.tableNames,
+            columnSchema =
+                tableSchema.columnSchema.copy(
+                    finalSchema =
+                        mapOf(
+                            Meta.COLUMN_NAME_DATA to
+                                ColumnType(SnowflakeDataType.OBJECT.typeName, false)
+                        )
+                ),
+            importType = tableSchema.importType,
+        )
+    }
+}
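To illustrate the two naming modes of toFinalTableName, a hypothetical example. The raw-mode name comes from TypingDedupingUtil.concatenateRawTableName; the `{namespace}_raw__stream_{name}` shape shown in the comments is Airbyte's long-standing convention and is assumed here rather than shown in this diff, as is the Descriptor constructor shape:

import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeTableSchemaMapper

fun printFinalTableName(mapper: SnowflakeTableSchemaMapper) {
    val desc = DestinationStream.Descriptor(namespace = "public", name = "users")
    // Schema mode (legacyRawTablesOnly = false): identifiers are uppercased,
    //   TableName(namespace = "PUBLIC", name = "USERS")
    // Legacy raw mode (legacyRawTablesOnly = true), assuming the conventional shape:
    //   TableName(namespace = config.internalTableSchema, name = "public_raw__stream_users")
    println(mapper.toFinalTableName(desc))
}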

View File

@@ -1,201 +0,0 @@
-/*
- * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.snowflake.sql
-
-import com.google.common.annotations.VisibleForTesting
-import io.airbyte.cdk.load.data.AirbyteType
-import io.airbyte.cdk.load.data.ArrayType
-import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
-import io.airbyte.cdk.load.data.BooleanType
-import io.airbyte.cdk.load.data.DateType
-import io.airbyte.cdk.load.data.FieldType
-import io.airbyte.cdk.load.data.IntegerType
-import io.airbyte.cdk.load.data.NumberType
-import io.airbyte.cdk.load.data.ObjectType
-import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
-import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
-import io.airbyte.cdk.load.data.StringType
-import io.airbyte.cdk.load.data.TimeTypeWithTimezone
-import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
-import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
-import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
-import io.airbyte.cdk.load.data.UnionType
-import io.airbyte.cdk.load.data.UnknownType
-import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
-import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
-import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT
-import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
-import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
-import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA
-import io.airbyte.cdk.load.table.ColumnNameMapping
-import io.airbyte.integrations.destination.snowflake.db.SnowflakeColumnNameGenerator
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
-import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import jakarta.inject.Singleton
-import kotlin.collections.component1
-import kotlin.collections.component2
-import kotlin.collections.joinToString
-import kotlin.collections.map
-import kotlin.collections.plus
-
-internal const val NOT_NULL = "NOT NULL"
-
-internal val DEFAULT_COLUMNS =
-    listOf(
-        ColumnAndType(
-            columnName = COLUMN_NAME_AB_RAW_ID,
-            columnType = "${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL"
-        ),
-        ColumnAndType(
-            columnName = COLUMN_NAME_AB_EXTRACTED_AT,
-            columnType = "${SnowflakeDataType.TIMESTAMP_TZ.typeName} $NOT_NULL"
-        ),
-        ColumnAndType(
-            columnName = COLUMN_NAME_AB_META,
-            columnType = "${SnowflakeDataType.VARIANT.typeName} $NOT_NULL"
-        ),
-        ColumnAndType(
-            columnName = COLUMN_NAME_AB_GENERATION_ID,
-            columnType = SnowflakeDataType.NUMBER.typeName
-        ),
-    )
-
-internal val RAW_DATA_COLUMN =
-    ColumnAndType(
-        columnName = COLUMN_NAME_DATA,
-        columnType = "${SnowflakeDataType.VARIANT.typeName} $NOT_NULL"
-    )
-
-internal val RAW_COLUMNS =
-    listOf(
-        ColumnAndType(
-            columnName = COLUMN_NAME_AB_LOADED_AT,
-            columnType = SnowflakeDataType.TIMESTAMP_TZ.typeName
-        ),
-        RAW_DATA_COLUMN
-    )
-
-@Singleton
-class SnowflakeColumnUtils(
-    private val snowflakeConfiguration: SnowflakeConfiguration,
-    private val snowflakeColumnNameGenerator: SnowflakeColumnNameGenerator,
-) {
-    @VisibleForTesting
-    internal fun defaultColumns(): List<ColumnAndType> =
-        if (snowflakeConfiguration.legacyRawTablesOnly) {
-            DEFAULT_COLUMNS + RAW_COLUMNS
-        } else {
-            DEFAULT_COLUMNS
-        }
-
-    internal fun formattedDefaultColumns(): List<ColumnAndType> =
-        defaultColumns().map {
-            ColumnAndType(
-                columnName = formatColumnName(it.columnName, false),
-                columnType = it.columnType,
-            )
-        }
-
-    fun getGenerationIdColumnName(): String {
-        return if (snowflakeConfiguration.legacyRawTablesOnly) {
-            COLUMN_NAME_AB_GENERATION_ID
-        } else {
-            COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()
-        }
-    }
-
-    fun getColumnNames(columnNameMapping: ColumnNameMapping): String =
-        if (snowflakeConfiguration.legacyRawTablesOnly) {
-            getFormattedDefaultColumnNames(true).joinToString(",")
-        } else {
-            (getFormattedDefaultColumnNames(true) +
-                columnNameMapping.map { (_, actualName) -> actualName.quote() })
-                .joinToString(",")
-        }
-
-    fun getFormattedDefaultColumnNames(quote: Boolean = false): List<String> =
-        defaultColumns().map { formatColumnName(it.columnName, quote) }
-
-    fun getFormattedColumnNames(
-        columns: Map<String, FieldType>,
-        columnNameMapping: ColumnNameMapping,
-        quote: Boolean = true,
-    ): List<String> =
-        if (snowflakeConfiguration.legacyRawTablesOnly) {
-            getFormattedDefaultColumnNames(quote)
-        } else {
-            getFormattedDefaultColumnNames(quote) +
-                columns.map { (fieldName, _) ->
-                    val columnName = columnNameMapping[fieldName] ?: fieldName
-                    if (quote) columnName.quote() else columnName
-                }
-        }
-
-    fun columnsAndTypes(
-        columns: Map<String, FieldType>,
-        columnNameMapping: ColumnNameMapping
-    ): List<ColumnAndType> =
-        if (snowflakeConfiguration.legacyRawTablesOnly) {
-            formattedDefaultColumns()
-        } else {
-            formattedDefaultColumns() +
-                columns.map { (fieldName, type) ->
-                    val columnName = columnNameMapping[fieldName] ?: fieldName
-                    val typeName = toDialectType(type.type)
-                    ColumnAndType(
-                        columnName = columnName,
-                        columnType = if (type.nullable) typeName else "$typeName $NOT_NULL",
-                    )
-                }
-        }
-
-    fun formatColumnName(
-        columnName: String,
-        quote: Boolean = true,
-    ): String {
-        val formattedColumnName =
-            if (columnName == COLUMN_NAME_DATA) columnName
-            else snowflakeColumnNameGenerator.getColumnName(columnName).displayName
-        return if (quote) formattedColumnName.quote() else formattedColumnName
-    }
-
-    fun toDialectType(type: AirbyteType): String =
-        when (type) {
-            // Simple types
-            BooleanType -> SnowflakeDataType.BOOLEAN.typeName
-            IntegerType -> SnowflakeDataType.NUMBER.typeName
-            NumberType -> SnowflakeDataType.FLOAT.typeName
-            StringType -> SnowflakeDataType.VARCHAR.typeName
-
-            // Temporal types
-            DateType -> SnowflakeDataType.DATE.typeName
-            TimeTypeWithTimezone -> SnowflakeDataType.VARCHAR.typeName
-            TimeTypeWithoutTimezone -> SnowflakeDataType.TIME.typeName
-            TimestampTypeWithTimezone -> SnowflakeDataType.TIMESTAMP_TZ.typeName
-            TimestampTypeWithoutTimezone -> SnowflakeDataType.TIMESTAMP_NTZ.typeName
-
-            // Semistructured types
-            is ArrayType,
-            ArrayTypeWithoutSchema -> SnowflakeDataType.ARRAY.typeName
-            is ObjectType,
-            ObjectTypeWithEmptySchema,
-            ObjectTypeWithoutSchema -> SnowflakeDataType.OBJECT.typeName
-            is UnionType -> SnowflakeDataType.VARIANT.typeName
-            is UnknownType -> SnowflakeDataType.VARIANT.typeName
-        }
-}
-
-data class ColumnAndType(val columnName: String, val columnType: String) {
-    override fun toString(): String {
-        return "${columnName.quote()} $columnType"
-    }
-}
-
-/**
- * Surrounds the string instance with double quotation marks (e.g. "some string" -> "\"some
- * string\"").
- */
-fun String.quote() = "$QUOTE$this$QUOTE"

View File

@@ -10,7 +10,7 @@ package io.airbyte.integrations.destination.snowflake.sql
  */
 enum class SnowflakeDataType(val typeName: String) {
     // Numeric types
-    NUMBER("NUMBER(38,0)"),
+    NUMBER("NUMBER"),
     FLOAT("FLOAT"),
 
     // String & binary types

View File

@@ -4,16 +4,19 @@
package io.airbyte.integrations.destination.snowflake.sql package io.airbyte.integrations.destination.snowflake.sql
import io.airbyte.cdk.load.command.Dedupe import com.google.common.annotations.VisibleForTesting
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.component.ColumnType import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.component.ColumnTypeChange import io.airbyte.cdk.load.component.ColumnTypeChange
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.cdk.load.schema.model.StreamTableSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.util.UUIDGenerator import io.airbyte.cdk.load.util.UUIDGenerator
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.CdcDeletionMode import io.airbyte.integrations.destination.snowflake.spec.CdcDeletionMode
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.write.load.CSV_FIELD_SEPARATOR import io.airbyte.integrations.destination.snowflake.write.load.CSV_FIELD_SEPARATOR
@@ -22,6 +25,14 @@ import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton import jakarta.inject.Singleton
internal const val COUNT_TOTAL_ALIAS = "TOTAL" internal const val COUNT_TOTAL_ALIAS = "TOTAL"
internal const val NOT_NULL = "NOT NULL"
// Snowflake-compatible (uppercase) versions of the Airbyte meta column names
internal val SNOWFLAKE_AB_RAW_ID = COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_EXTRACTED_AT = COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_META = COLUMN_NAME_AB_META.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_GENERATION_ID = COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN = CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()
private val log = KotlinLogging.logger {} private val log = KotlinLogging.logger {}
@@ -36,80 +47,91 @@ fun String.andLog(): String {
@Singleton @Singleton
class SnowflakeDirectLoadSqlGenerator( class SnowflakeDirectLoadSqlGenerator(
private val columnUtils: SnowflakeColumnUtils,
private val uuidGenerator: UUIDGenerator, private val uuidGenerator: UUIDGenerator,
private val snowflakeConfiguration: SnowflakeConfiguration, private val config: SnowflakeConfiguration,
private val snowflakeSqlNameUtils: SnowflakeSqlNameUtils, private val columnManager: SnowflakeColumnManager,
) { ) {
fun countTable(tableName: TableName): String { fun countTable(tableName: TableName): String {
return "SELECT COUNT(*) AS $COUNT_TOTAL_ALIAS FROM ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog() return "SELECT COUNT(*) AS $COUNT_TOTAL_ALIAS FROM ${fullyQualifiedName(tableName)}".andLog()
} }
fun createNamespace(namespace: String): String { fun createNamespace(namespace: String): String {
return "CREATE SCHEMA IF NOT EXISTS ${snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)}".andLog() return "CREATE SCHEMA IF NOT EXISTS ${fullyQualifiedNamespace(namespace)}".andLog()
} }
fun createTable( fun createTable(
stream: DestinationStream,
tableName: TableName, tableName: TableName,
columnNameMapping: ColumnNameMapping, tableSchema: StreamTableSchema,
replace: Boolean replace: Boolean
): String { ): String {
val finalSchema = tableSchema.columnSchema.finalSchema
val metaColumns = columnManager.getMetaColumns()
// Build column declarations from the meta columns and user schema
val columnDeclarations = val columnDeclarations =
columnUtils buildList {
.columnsAndTypes(stream.schema.asColumns(), columnNameMapping) // Add Airbyte meta columns from the column manager
.joinToString(",\n") metaColumns.forEach { (columnName, columnType) ->
val nullability = if (columnType.nullable) "" else " NOT NULL"
add("${columnName.quote()} ${columnType.type}$nullability")
}
// Add user columns from the munged schema
finalSchema.forEach { (columnName, columnType) ->
val nullability = if (columnType.nullable) "" else " NOT NULL"
add("${columnName.quote()} ${columnType.type}$nullability")
}
}
.joinToString(",\n ")
// Snowflake supports CREATE OR REPLACE TABLE, which is simpler than drop+recreate // Snowflake supports CREATE OR REPLACE TABLE, which is simpler than drop+recreate
val createOrReplace = if (replace) "CREATE OR REPLACE" else "CREATE" val createOrReplace = if (replace) "CREATE OR REPLACE" else "CREATE"
val createTableStatement = val createTableStatement =
""" """
$createOrReplace TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)} ( |$createOrReplace TABLE ${fullyQualifiedName(tableName)} (
$columnDeclarations | $columnDeclarations
) |)
""".trimIndent() """.trimMargin() // Something was tripping up trimIndent so we opt for trimMargin
return createTableStatement.andLog() return createTableStatement.andLog()
} }
fun showColumns(tableName: TableName): String = fun showColumns(tableName: TableName): String =
"SHOW COLUMNS IN TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog() "SHOW COLUMNS IN TABLE ${fullyQualifiedName(tableName)}".andLog()
fun copyTable( fun copyTable(
columnNameMapping: ColumnNameMapping, columnNames: Set<String>,
sourceTableName: TableName, sourceTableName: TableName,
targetTableName: TableName targetTableName: TableName
): String { ): String {
val columnNames = columnUtils.getColumnNames(columnNameMapping) val columnList = columnNames.joinToString(", ") { it.quote() }
return """ return """
INSERT INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} INSERT INTO ${fullyQualifiedName(targetTableName)}
( (
$columnNames $columnList
) )
SELECT SELECT
$columnNames $columnList
FROM ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} FROM ${fullyQualifiedName(sourceTableName)}
""" """
.trimIndent() .trimIndent()
.andLog() .andLog()
} }
fun upsertTable( fun upsertTable(
stream: DestinationStream, tableSchema: StreamTableSchema,
columnNameMapping: ColumnNameMapping,
sourceTableName: TableName, sourceTableName: TableName,
targetTableName: TableName targetTableName: TableName
): String { ): String {
val importType = stream.importType as Dedupe val finalSchema = tableSchema.columnSchema.finalSchema
// Build primary key matching condition // Build primary key matching condition
val pks = tableSchema.getPrimaryKey().flatten()
val pkEquivalent = val pkEquivalent =
if (importType.primaryKey.isNotEmpty()) { if (pks.isNotEmpty()) {
importType.primaryKey.joinToString(" AND ") { fieldPath -> pks.joinToString(" AND ") { columnName ->
val fieldName = fieldPath.first()
val columnName = columnNameMapping[fieldName] ?: fieldName
val targetTableColumnName = "target_table.${columnName.quote()}" val targetTableColumnName = "target_table.${columnName.quote()}"
val newRecordColumnName = "new_record.${columnName.quote()}" val newRecordColumnName = "new_record.${columnName.quote()}"
"""($targetTableColumnName = $newRecordColumnName OR ($targetTableColumnName IS NULL AND $newRecordColumnName IS NULL))""" """($targetTableColumnName = $newRecordColumnName OR ($targetTableColumnName IS NULL AND $newRecordColumnName IS NULL))"""
@@ -120,80 +142,62 @@ class SnowflakeDirectLoadSqlGenerator(
} }
// Build column lists for INSERT and UPDATE // Build column lists for INSERT and UPDATE
val columnList: String = val allColumns = buildList {
columnUtils add(SNOWFLAKE_AB_RAW_ID)
.getFormattedColumnNames( add(SNOWFLAKE_AB_EXTRACTED_AT)
columns = stream.schema.asColumns(), add(SNOWFLAKE_AB_META)
columnNameMapping = columnNameMapping, add(SNOWFLAKE_AB_GENERATION_ID)
quote = false, addAll(finalSchema.keys)
) }
.joinToString(
",\n",
) {
it.quote()
}
val columnList: String = allColumns.joinToString(",\n ") { it.quote() }
val newRecordColumnList: String = val newRecordColumnList: String =
columnUtils allColumns.joinToString(",\n ") { "new_record.${it.quote()}" }
.getFormattedColumnNames(
columns = stream.schema.asColumns(),
columnNameMapping = columnNameMapping,
quote = false,
)
.joinToString(",\n") { "new_record.${it.quote()}" }
// Get deduped records from source // Get deduped records from source
val selectSourceRecords = selectDedupedRecords(stream, sourceTableName, columnNameMapping) val selectSourceRecords = selectDedupedRecords(tableSchema, sourceTableName)
// Build cursor comparison for determining which record is newer // Build cursor comparison for determining which record is newer
val cursorComparison: String val cursorComparison: String
if (importType.cursor.isNotEmpty()) { val cursor = tableSchema.getCursor().firstOrNull()
val cursorFieldName = importType.cursor.first() if (cursor != null) {
val cursor = (columnNameMapping[cursorFieldName] ?: cursorFieldName)
val targetTableCursor = "target_table.${cursor.quote()}" val targetTableCursor = "target_table.${cursor.quote()}"
val newRecordCursor = "new_record.${cursor.quote()}" val newRecordCursor = "new_record.${cursor.quote()}"
cursorComparison = cursorComparison =
""" """
( (
$targetTableCursor < $newRecordCursor $targetTableCursor < $newRecordCursor
OR ($targetTableCursor = $newRecordCursor AND target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}") OR ($targetTableCursor = $newRecordCursor AND target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT")
OR ($targetTableCursor IS NULL AND $newRecordCursor IS NULL AND target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}") OR ($targetTableCursor IS NULL AND $newRecordCursor IS NULL AND target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT")
OR ($targetTableCursor IS NULL AND $newRecordCursor IS $NOT_NULL) OR ($targetTableCursor IS NULL AND $newRecordCursor IS $NOT_NULL)
) )
""".trimIndent() """.trimIndent()
} else { } else {
// No cursor - use extraction timestamp only // No cursor - use extraction timestamp only
cursorComparison = cursorComparison =
"""target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}"""" """target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT""""
} }
// Build column assignments for UPDATE // Build column assignments for UPDATE
val columnAssignments: String = val columnAssignments: String =
columnUtils allColumns.joinToString(",\n ") { column ->
.getFormattedColumnNames( "${column.quote()} = new_record.${column.quote()}"
columns = stream.schema.asColumns(), }
columnNameMapping = columnNameMapping,
quote = false,
)
.joinToString(",\n") { column ->
"${column.quote()} = new_record.${column.quote()}"
}
// Handle CDC deletions based on mode // Handle CDC deletions based on mode
val cdcDeleteClause: String val cdcDeleteClause: String
val cdcSkipInsertClause: String val cdcSkipInsertClause: String
if ( if (
stream.schema.asColumns().containsKey(CDC_DELETED_AT_COLUMN) && finalSchema.containsKey(SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN) &&
snowflakeConfiguration.cdcDeletionMode == CdcDeletionMode.HARD_DELETE config.cdcDeletionMode == CdcDeletionMode.HARD_DELETE
) { ) {
// Execute CDC deletions if there's already a record // Execute CDC deletions if there's already a record
cdcDeleteClause = cdcDeleteClause =
"WHEN MATCHED AND new_record.\"${CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()}\" IS NOT NULL AND $cursorComparison THEN DELETE" "WHEN MATCHED AND new_record.\"${SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN}\" IS NOT NULL AND $cursorComparison THEN DELETE"
// And skip insertion entirely if there's no matching record. // And skip insertion entirely if there's no matching record.
// (This is possible if a single T+D batch contains both an insertion and deletion for // (This is possible if a single T+D batch contains both an insertion and deletion for
// the same PK) // the same PK)
cdcSkipInsertClause = cdcSkipInsertClause = "AND new_record.\"${SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN}\" IS NULL"
"AND new_record.\"${CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()}\" IS NULL"
} else { } else {
cdcDeleteClause = "" cdcDeleteClause = ""
cdcSkipInsertClause = "" cdcSkipInsertClause = ""
@@ -203,35 +207,35 @@ class SnowflakeDirectLoadSqlGenerator(
val mergeStatement = val mergeStatement =
if (cdcDeleteClause.isNotEmpty()) { if (cdcDeleteClause.isNotEmpty()) {
""" """
MERGE INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} AS target_table |MERGE INTO ${fullyQualifiedName(targetTableName)} AS target_table
USING ( |USING (
$selectSourceRecords |$selectSourceRecords
) AS new_record |) AS new_record
ON $pkEquivalent |ON $pkEquivalent
$cdcDeleteClause |$cdcDeleteClause
WHEN MATCHED AND $cursorComparison THEN UPDATE SET |WHEN MATCHED AND $cursorComparison THEN UPDATE SET
$columnAssignments | $columnAssignments
WHEN NOT MATCHED $cdcSkipInsertClause THEN INSERT ( |WHEN NOT MATCHED $cdcSkipInsertClause THEN INSERT (
$columnList | $columnList
) VALUES ( |) VALUES (
$newRecordColumnList | $newRecordColumnList
) |)
""".trimIndent() """.trimMargin()
} else { } else {
""" """
MERGE INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} AS target_table |MERGE INTO ${fullyQualifiedName(targetTableName)} AS target_table
USING ( |USING (
$selectSourceRecords |$selectSourceRecords
) AS new_record |) AS new_record
ON $pkEquivalent |ON $pkEquivalent
WHEN MATCHED AND $cursorComparison THEN UPDATE SET |WHEN MATCHED AND $cursorComparison THEN UPDATE SET
$columnAssignments | $columnAssignments
WHEN NOT MATCHED THEN INSERT ( |WHEN NOT MATCHED THEN INSERT (
$columnList | $columnList
) VALUES ( |) VALUES (
$newRecordColumnList | $newRecordColumnList
) |)
""".trimIndent() """.trimMargin()
} }
return mergeStatement.andLog() return mergeStatement.andLog()
@@ -242,75 +246,66 @@ class SnowflakeDirectLoadSqlGenerator(
     * table. Uses ROW_NUMBER() window function to select the most recent record per primary key.
     */
    private fun selectDedupedRecords(
-        stream: DestinationStream,
-        sourceTableName: TableName,
-        columnNameMapping: ColumnNameMapping
+        tableSchema: StreamTableSchema,
+        sourceTableName: TableName
    ): String {
-        val columnList: String =
-            columnUtils
-                .getFormattedColumnNames(
-                    columns = stream.schema.asColumns(),
-                    columnNameMapping = columnNameMapping,
-                    quote = false,
-                )
-                .joinToString(
-                    ",\n",
-                ) {
-                    it.quote()
-                }
-
-        val importType = stream.importType as Dedupe
+        val allColumns = buildList {
+            add(SNOWFLAKE_AB_RAW_ID)
+            add(SNOWFLAKE_AB_EXTRACTED_AT)
+            add(SNOWFLAKE_AB_META)
+            add(SNOWFLAKE_AB_GENERATION_ID)
+            addAll(tableSchema.columnSchema.finalSchema.keys)
+        }
+        val columnList: String = allColumns.joinToString(",\n    ") { it.quote() }

        // Build the primary key list for partitioning
+        val pks = tableSchema.getPrimaryKey().flatten()
        val pkList =
-            if (importType.primaryKey.isNotEmpty()) {
-                importType.primaryKey.joinToString(",") { fieldPath ->
-                    (columnNameMapping[fieldPath.first()] ?: fieldPath.first()).quote()
-                }
+            if (pks.isNotEmpty()) {
+                pks.joinToString(",") { it.quote() }
            } else {
                // Should not happen as we check this earlier, but handle it defensively
                throw IllegalArgumentException("Cannot deduplicate without primary key")
            }

        // Build cursor order clause for sorting within each partition
+        val cursor = tableSchema.getCursor().firstOrNull()
        val cursorOrderClause =
-            if (importType.cursor.isNotEmpty()) {
-                val columnName =
-                    (columnNameMapping[importType.cursor.first()] ?: importType.cursor.first())
-                        .quote()
-                "$columnName DESC NULLS LAST,"
+            if (cursor != null) {
+                "${cursor.quote()} DESC NULLS LAST,"
            } else {
                ""
            }

        return """
-            WITH records AS (
-            SELECT
-            $columnList
-            FROM ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)}
-            ), numbered_rows AS (
-            SELECT *, ROW_NUMBER() OVER (
-            PARTITION BY $pkList ORDER BY $cursorOrderClause "${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" DESC
-            ) AS row_number
-            FROM records
-            )
-            SELECT $columnList
-            FROM numbered_rows
-            WHERE row_number = 1
-            """
-            .trimIndent()
+            | WITH records AS (
+            |   SELECT
+            |     $columnList
+            |   FROM ${fullyQualifiedName(sourceTableName)}
+            | ), numbered_rows AS (
+            |   SELECT *, ROW_NUMBER() OVER (
+            |     PARTITION BY $pkList ORDER BY $cursorOrderClause "$SNOWFLAKE_AB_EXTRACTED_AT" DESC
+            |   ) AS row_number
+            |   FROM records
+            | )
+            | SELECT $columnList
+            | FROM numbered_rows
+            | WHERE row_number = 1
+            """
+            .trimMargin()
            .andLog()
    }
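A sketch of the deduplication query shape this produces, assuming a hypothetical primary key "ID" and cursor "UPDATED_AT"; ties within a partition fall back to the extraction timestamp:

val exampleDedup = """
    | WITH records AS (
    |   SELECT "_AIRBYTE_RAW_ID", "_AIRBYTE_EXTRACTED_AT", "_AIRBYTE_META", "_AIRBYTE_GENERATION_ID", "ID", "UPDATED_AT"
    |   FROM "DB"."SCHEMA"."USERS_AIRBYTE_TMP"
    | ), numbered_rows AS (
    |   SELECT *, ROW_NUMBER() OVER (
    |     PARTITION BY "ID" ORDER BY "UPDATED_AT" DESC NULLS LAST, "_AIRBYTE_EXTRACTED_AT" DESC
    |   ) AS row_number
    |   FROM records
    | )
    | SELECT "_AIRBYTE_RAW_ID", "_AIRBYTE_EXTRACTED_AT", "_AIRBYTE_META", "_AIRBYTE_GENERATION_ID", "ID", "UPDATED_AT"
    | FROM numbered_rows
    | WHERE row_number = 1
    """.trimMargin()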
    fun dropTable(tableName: TableName): String {
-        return "DROP TABLE IF EXISTS ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog()
+        return "DROP TABLE IF EXISTS ${fullyQualifiedName(tableName)}".andLog()
    }

    fun getGenerationId(
        tableName: TableName,
    ): String {
        return """
-            SELECT "${columnUtils.getGenerationIdColumnName()}"
-            FROM ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}
+            SELECT "${columnManager.getGenerationIdColumnName()}"
+            FROM ${fullyQualifiedName(tableName)}
            LIMIT 1
            """
            .trimIndent()
@@ -318,12 +313,12 @@ class SnowflakeDirectLoadSqlGenerator(
    }

    fun createSnowflakeStage(tableName: TableName): String {
-        val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName)
+        val stageName = fullyQualifiedStageName(tableName)
        return "CREATE STAGE IF NOT EXISTS $stageName".andLog()
    }

    fun putInStage(tableName: TableName, tempFilePath: String): String {
-        val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true)
+        val stageName = fullyQualifiedStageName(tableName, true)
        return """
            PUT 'file://$tempFilePath' '@$stageName'
            AUTO_COMPRESS = FALSE
@@ -334,35 +329,45 @@ class SnowflakeDirectLoadSqlGenerator(
            .andLog()
    }

-    fun copyFromStage(tableName: TableName, filename: String): String {
-        val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true)
+    fun copyFromStage(
+        tableName: TableName,
+        filename: String,
+        columnNames: List<String>? = null
+    ): String {
+        val stageName = fullyQualifiedStageName(tableName, true)
+        val columnList =
+            columnNames?.let { names -> "(${names.joinToString(", ") { it.quote() }})" } ?: ""
        return """
-            COPY INTO ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}
-            FROM '@$stageName'
-            FILE_FORMAT = (
-            TYPE = 'CSV'
-            COMPRESSION = GZIP
-            FIELD_DELIMITER = '$CSV_FIELD_SEPARATOR'
-            RECORD_DELIMITER = '$CSV_LINE_DELIMITER'
-            FIELD_OPTIONALLY_ENCLOSED_BY = '"'
-            TRIM_SPACE = TRUE
-            ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE
-            REPLACE_INVALID_CHARACTERS = TRUE
-            ESCAPE = NONE
-            ESCAPE_UNENCLOSED_FIELD = NONE
-            )
-            ON_ERROR = 'ABORT_STATEMENT'
-            PURGE = TRUE
-            files = ('$filename')
-            """
-            .trimIndent()
+            |COPY INTO ${fullyQualifiedName(tableName)}$columnList
+            |FROM '@$stageName'
+            |FILE_FORMAT = (
+            |  TYPE = 'CSV'
+            |  COMPRESSION = GZIP
+            |  FIELD_DELIMITER = '$CSV_FIELD_SEPARATOR'
+            |  RECORD_DELIMITER = '$CSV_LINE_DELIMITER'
+            |  FIELD_OPTIONALLY_ENCLOSED_BY = '"'
+            |  TRIM_SPACE = TRUE
+            |  ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE
+            |  REPLACE_INVALID_CHARACTERS = TRUE
+            |  ESCAPE = NONE
+            |  ESCAPE_UNENCLOSED_FIELD = NONE
+            |)
+            |ON_ERROR = 'ABORT_STATEMENT'
+            |PURGE = TRUE
+            |files = ('$filename')
+            """
+            .trimMargin()
            .andLog()
    }
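The new optional columnNames parameter pins each CSV field to a named table column, so the COPY stays aligned even after ALTER TABLE has appended or rebuilt columns. A hedged usage sketch (schema, file, and column names invented for illustration):

val copySql = copyFromStage(
    tableName = TableName("MY_SCHEMA", "USERS"),
    filename = "part-0001.csv.gz",
    columnNames = listOf("_AIRBYTE_RAW_ID", "_AIRBYTE_EXTRACTED_AT", "_AIRBYTE_META", "_AIRBYTE_GENERATION_ID", "ID"),
)
// The generated statement then begins with an explicit column list:
//   COPY INTO "DB"."MY_SCHEMA"."USERS"("_AIRBYTE_RAW_ID", ..., "ID")
// rather than relying on the table's current column order.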
    fun swapTableWith(sourceTableName: TableName, targetTableName: TableName): String {
        return """
-            ALTER TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} SWAP WITH ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)}
+            ALTER TABLE ${fullyQualifiedName(sourceTableName)} SWAP WITH ${
+                fullyQualifiedName(
+                    targetTableName,
+                )
+            }
            """
            .trimIndent()
            .andLog()
@@ -372,7 +377,11 @@ class SnowflakeDirectLoadSqlGenerator(
        // Snowflake RENAME TO only accepts the table name, not a fully qualified name
        // The renamed table stays in the same schema
        return """
-            ALTER TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} RENAME TO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)}
+            ALTER TABLE ${fullyQualifiedName(sourceTableName)} RENAME TO ${
+                fullyQualifiedName(
+                    targetTableName,
+                )
+            }
            """
            .trimIndent()
            .andLog()
@@ -382,7 +391,7 @@ class SnowflakeDirectLoadSqlGenerator(
        schemaName: String,
        tableName: String,
    ): String =
-        """DESCRIBE TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(TableName(schemaName, tableName))}""".andLog()
+        """DESCRIBE TABLE ${fullyQualifiedName(TableName(schemaName, tableName))}""".andLog()

    fun alterTable(
        tableName: TableName,
@@ -391,14 +400,14 @@ class SnowflakeDirectLoadSqlGenerator(
        modifiedColumns: Map<String, ColumnTypeChange>,
    ): Set<String> {
        val clauses = mutableSetOf<String>()
-        val prettyTableName = snowflakeSqlNameUtils.fullyQualifiedName(tableName)
+        val prettyTableName = fullyQualifiedName(tableName)
        addedColumns.forEach { (name, columnType) ->
            clauses.add(
                // Note that we intentionally don't set NOT NULL.
                // We're adding a new column, and we don't know what constitutes a reasonable
                // default value for preexisting records.
                // So we add the column as nullable.
-                "ALTER TABLE $prettyTableName ADD COLUMN ${name.quote()} ${columnType.type};".andLog()
+                "ALTER TABLE $prettyTableName ADD COLUMN ${name.quote()} ${columnType.type};".andLog(),
            )
        }
        deletedColumns.forEach {
@@ -412,35 +421,34 @@ class SnowflakeDirectLoadSqlGenerator(
val tempColumn = "${name}_${uuidGenerator.v4()}" val tempColumn = "${name}_${uuidGenerator.v4()}"
clauses.add( clauses.add(
// As above: we add the column as nullable. // As above: we add the column as nullable.
"ALTER TABLE $prettyTableName ADD COLUMN ${tempColumn.quote()} ${typeChange.newType.type};".andLog() "ALTER TABLE $prettyTableName ADD COLUMN ${tempColumn.quote()} ${typeChange.newType.type};".andLog(),
) )
clauses.add( clauses.add(
"UPDATE $prettyTableName SET ${tempColumn.quote()} = CAST(${name.quote()} AS ${typeChange.newType.type});".andLog() "UPDATE $prettyTableName SET ${tempColumn.quote()} = CAST(${name.quote()} AS ${typeChange.newType.type});".andLog(),
) )
val backupColumn = "${tempColumn}_backup" val backupColumn = "${tempColumn}_backup"
clauses.add( clauses.add(
""" """
ALTER TABLE $prettyTableName ALTER TABLE $prettyTableName
RENAME COLUMN "$name" TO "$backupColumn"; RENAME COLUMN "$name" TO "$backupColumn";
""".trimIndent() """.trimIndent(),
) )
clauses.add( clauses.add(
""" """
ALTER TABLE $prettyTableName ALTER TABLE $prettyTableName
RENAME COLUMN "$tempColumn" TO "$name"; RENAME COLUMN "$tempColumn" TO "$name";
""".trimIndent() """.trimIndent(),
) )
clauses.add( clauses.add(
"ALTER TABLE $prettyTableName DROP COLUMN ${backupColumn.quote()};".andLog() "ALTER TABLE $prettyTableName DROP COLUMN ${backupColumn.quote()};".andLog(),
) )
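Taken together, the clauses above migrate a column's type without losing data: add a nullable column under a temporary name, backfill it with a CAST, swap the names via a backup column, then drop the backup. Roughly the emitted sequence, with a hypothetical table and column and the UUID shortened:

val expectedClauses = listOf(
    """ALTER TABLE "DB"."SCHEMA"."USERS" ADD COLUMN "AGE_9f1c" VARCHAR;""",
    """UPDATE "DB"."SCHEMA"."USERS" SET "AGE_9f1c" = CAST("AGE" AS VARCHAR);""",
    """ALTER TABLE "DB"."SCHEMA"."USERS" RENAME COLUMN "AGE" TO "AGE_9f1c_backup";""",
    """ALTER TABLE "DB"."SCHEMA"."USERS" RENAME COLUMN "AGE_9f1c" TO "AGE";""",
    """ALTER TABLE "DB"."SCHEMA"."USERS" DROP COLUMN "AGE_9f1c_backup";""",
)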
        } else if (!typeChange.originalType.nullable && typeChange.newType.nullable) {
            // If the type is unchanged, we can change a column from NOT NULL to nullable.
-            // But we'll never do the reverse, because there's a decent chance that historical
-            // records
-            // had null values.
+            // But we'll never do the reverse, because there's a decent chance that historical
+            // records had null values.
            // Users can always manually ALTER COLUMN ... SET NOT NULL if they want.
            clauses.add(
-                """ALTER TABLE $prettyTableName ALTER COLUMN "$name" DROP NOT NULL;""".andLog()
+                """ALTER TABLE $prettyTableName ALTER COLUMN "$name" DROP NOT NULL;""".andLog(),
            )
        } else {
            log.info {
@@ -450,4 +458,45 @@ class SnowflakeDirectLoadSqlGenerator(
        }
        return clauses
    }

+    @VisibleForTesting
+    fun fullyQualifiedName(tableName: TableName): String =
+        combineParts(listOf(getDatabaseName(), tableName.namespace, tableName.name))
+
+    @VisibleForTesting
+    fun fullyQualifiedNamespace(namespace: String) =
+        combineParts(listOf(getDatabaseName(), namespace))
+
+    @VisibleForTesting
+    fun fullyQualifiedStageName(tableName: TableName, escape: Boolean = false): String {
+        val currentTableName =
+            if (escape) {
+                tableName.name
+            } else {
+                tableName.name
+            }
+        return combineParts(
+            parts =
+                listOf(
+                    getDatabaseName(),
+                    tableName.namespace,
+                    "$STAGE_NAME_PREFIX$currentTableName",
+                ),
+            escape = escape,
+        )
+    }
+
+    @VisibleForTesting
+    internal fun combineParts(parts: List<String>, escape: Boolean = false): String =
+        parts
+            .map { if (escape) sqlEscape(it) else it }
+            .joinToString(separator = ".") {
+                if (!it.startsWith(QUOTE)) {
+                    "$QUOTE$it$QUOTE"
+                } else {
+                    it
+                }
+            }
+
+    private fun getDatabaseName() = config.database.toSnowflakeCompatibleName()
}
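A quick illustration of the quoting rule in combineParts, assuming a database named ANALYTICS (all names hypothetical):

// Mirrors combineParts' behavior on parts that are not already quoted.
val qualified = listOf("ANALYTICS", "PUBLIC", "airbyte_stage_USERS")
    .joinToString(".") { if (!it.startsWith("\"")) "\"$it\"" else it }
// -> "ANALYTICS"."PUBLIC"."airbyte_stage_USERS"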


@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.integrations.destination.snowflake.sql
+
+const val STAGE_NAME_PREFIX = "airbyte_stage_"
+
+internal const val QUOTE: String = "\""
+
+fun sqlEscape(part: String) = part.replace("\\", "\\\\").replace("'", "\\'").replace("\"", "\\\"")
+
+/**
+ * Surrounds the string instance with double quotation marks (e.g. "some string" -> "\"some
+ * string\"").
+ */
+fun String.quote() = "$QUOTE$this$QUOTE"
+
+/**
+ * Escapes double-quotes in a JSON identifier by doubling them. This is legacy -- I don't know why
+ * this would be necessary but no harm in keeping it, so I am keeping it.
+ *
+ * @return The escaped identifier.
+ */
+fun escapeJsonIdentifier(identifier: String): String {
+    // Note that we don't need to escape backslashes here!
+    // The only special character in an identifier is the double-quote, which needs to be
+    // doubled.
+    return identifier.replace(QUOTE, "$QUOTE$QUOTE")
+}
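A small usage sketch of the helpers above:

fun helpersDemo() {
    println("my_table".quote())           // -> "my_table"
    println(escapeJsonIdentifier("a\"b")) // -> a""b (double-quote doubled)
    println(sqlEscape("it's"))            // -> it\'s (escaped for use inside a SQL string literal)
}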


@@ -1,57 +0,0 @@
-/*
- * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.snowflake.sql
-
-import io.airbyte.cdk.load.table.TableName
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
-import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import jakarta.inject.Singleton
-
-const val STAGE_NAME_PREFIX = "airbyte_stage_"
-internal const val QUOTE: String = "\""
-
-fun sqlEscape(part: String) = part.replace("\\", "\\\\").replace("'", "\\'").replace("\"", "\\\"")
-
-@Singleton
-class SnowflakeSqlNameUtils(
-    private val snowflakeConfiguration: SnowflakeConfiguration,
-) {
-    fun fullyQualifiedName(tableName: TableName): String =
-        combineParts(listOf(getDatabaseName(), tableName.namespace, tableName.name))
-
-    fun fullyQualifiedNamespace(namespace: String) =
-        combineParts(listOf(getDatabaseName(), namespace))
-
-    fun fullyQualifiedStageName(tableName: TableName, escape: Boolean = false): String {
-        val currentTableName =
-            if (escape) {
-                tableName.name
-            } else {
-                tableName.name
-            }
-        return combineParts(
-            parts =
-                listOf(
-                    getDatabaseName(),
-                    tableName.namespace,
-                    "$STAGE_NAME_PREFIX$currentTableName"
-                ),
-            escape = escape,
-        )
-    }
-
-    fun combineParts(parts: List<String>, escape: Boolean = false): String =
-        parts
-            .map { if (escape) sqlEscape(it) else it }
-            .joinToString(separator = ".") {
-                if (!it.startsWith(QUOTE)) {
-                    "$QUOTE$it$QUOTE"
-                } else {
-                    it
-                }
-            }
-
-    private fun getDatabaseName() = snowflakeConfiguration.database.toSnowflakeCompatibleName()
-}


@@ -6,38 +6,39 @@ package io.airbyte.integrations.destination.snowflake.write
import io.airbyte.cdk.SystemErrorException
import io.airbyte.cdk.load.command.Dedupe
+import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationStream
-import io.airbyte.cdk.load.orchestration.db.DatabaseInitialStatusGatherer
-import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadInitialStatus
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendStreamLoader
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendTruncateStreamLoader
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableDedupStreamLoader
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableDedupTruncateStreamLoader
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig
-import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
+import io.airbyte.cdk.load.table.ColumnNameMapping
+import io.airbyte.cdk.load.table.DatabaseInitialStatusGatherer
+import io.airbyte.cdk.load.table.TempTableNameGenerator
+import io.airbyte.cdk.load.table.directload.DirectLoadInitialStatus
+import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendStreamLoader
+import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendTruncateStreamLoader
+import io.airbyte.cdk.load.table.directload.DirectLoadTableDedupStreamLoader
+import io.airbyte.cdk.load.table.directload.DirectLoadTableDedupTruncateStreamLoader
+import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.write.DestinationWriter
import io.airbyte.cdk.load.write.StreamLoader
import io.airbyte.cdk.load.write.StreamStateStore
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
-import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
+import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
import jakarta.inject.Singleton

@Singleton
class SnowflakeWriter(
-    private val names: TableCatalog,
+    private val catalog: DestinationCatalog,
    private val stateGatherer: DatabaseInitialStatusGatherer<DirectLoadInitialStatus>,
    private val streamStateStore: StreamStateStore<DirectLoadTableExecutionConfig>,
    private val snowflakeClient: SnowflakeAirbyteClient,
-    private val tempTableNameGenerator: TempTableNameGenerator,
    private val snowflakeConfiguration: SnowflakeConfiguration,
+    private val tempTableNameGenerator: TempTableNameGenerator,
) : DestinationWriter {
    private lateinit var initialStatuses: Map<DestinationStream, DirectLoadInitialStatus>

    override suspend fun setup() {
-        names.values
-            .map { (tableNames, _) -> tableNames.finalTableName!!.namespace }
+        catalog.streams
+            .map { it.tableSchema.tableNames.finalTableName!!.namespace }
            .toSet()
            .forEach { snowflakeClient.createNamespace(it) }
@@ -45,15 +46,15 @@ class SnowflakeWriter(
                escapeJsonIdentifier(snowflakeConfiguration.internalTableSchema)
            )

-        initialStatuses = stateGatherer.gatherInitialStatus(names)
+        initialStatuses = stateGatherer.gatherInitialStatus()
    }

    override fun createStreamLoader(stream: DestinationStream): StreamLoader {
        val initialStatus = initialStatuses[stream]!!
-        val tableNameInfo = names[stream]!!
-        val realTableName = tableNameInfo.tableNames.finalTableName!!
-        val tempTableName = tempTableNameGenerator.generate(realTableName)
-        val columnNameMapping = tableNameInfo.columnNameMapping
+        val realTableName = stream.tableSchema.tableNames.finalTableName!!
+        val tempTableName = stream.tableSchema.tableNames.tempTableName!!
+        val columnNameMapping =
+            ColumnNameMapping(stream.tableSchema.columnSchema.inputToFinalColumnNames)
        return when (stream.minimumGenerationId) {
            0L ->
                when (stream.importType) {


@@ -9,11 +9,12 @@ import de.siegmar.fastcsv.writer.CsvWriter
import de.siegmar.fastcsv.writer.LineDelimiter
import de.siegmar.fastcsv.writer.QuoteStrategies
import io.airbyte.cdk.load.data.AirbyteValue
-import io.airbyte.cdk.load.table.TableName
+import io.airbyte.cdk.load.schema.model.ColumnSchema
+import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.QUOTE
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.github.oshai.kotlinlogging.KotlinLogging
import java.io.File
import java.io.OutputStream
@@ -36,10 +37,11 @@ private const val CSV_WRITER_BUFFER_SIZE = 1024 * 1024 // 1 MB
class SnowflakeInsertBuffer(
    private val tableName: TableName,
-    val columns: LinkedHashMap<String, String>,
    private val snowflakeClient: SnowflakeAirbyteClient,
    val snowflakeConfiguration: SnowflakeConfiguration,
-    val snowflakeColumnUtils: SnowflakeColumnUtils,
+    val columnSchema: ColumnSchema,
+    private val columnManager: SnowflakeColumnManager,
+    private val snowflakeRecordFormatter: SnowflakeRecordFormatter,
    private val flushLimit: Int = DEFAULT_FLUSH_LIMIT,
) {
@@ -57,12 +59,6 @@ class SnowflakeInsertBuffer(
            .lineDelimiter(CSV_LINE_DELIMITER)
            .quoteStrategy(QuoteStrategies.REQUIRED)

-    private val snowflakeRecordFormatter: SnowflakeRecordFormatter =
-        when (snowflakeConfiguration.legacyRawTablesOnly) {
-            true -> SnowflakeRawRecordFormatter(columns, snowflakeColumnUtils)
-            else -> SnowflakeSchemaRecordFormatter(columns, snowflakeColumnUtils)
-        }
-
    fun accumulate(recordFields: Map<String, AirbyteValue>) {
        if (csvFilePath == null) {
            val csvFile = createCsvFile()
@@ -92,7 +88,9 @@ class SnowflakeInsertBuffer(
"Copying staging data into ${tableName.toPrettyString(quote = QUOTE)}..." "Copying staging data into ${tableName.toPrettyString(quote = QUOTE)}..."
} }
// Finally, copy the data from the staging table to the final table // Finally, copy the data from the staging table to the final table
snowflakeClient.copyFromStage(tableName, filePath.fileName.toString()) // Pass column names to ensure correct mapping even after ALTER TABLE operations
val columnNames = columnManager.getTableColumnNames(columnSchema)
snowflakeClient.copyFromStage(tableName, filePath.fileName.toString(), columnNames)
logger.info { logger.info {
"Finished insert of $recordCount row(s) into ${tableName.toPrettyString(quote = QUOTE)}." "Finished insert of $recordCount row(s) into ${tableName.toPrettyString(quote = QUOTE)}."
} }
@@ -117,7 +115,9 @@ class SnowflakeInsertBuffer(
    private fun writeToCsvFile(record: Map<String, AirbyteValue>) {
        csvWriter?.let {
-            it.writeRecord(snowflakeRecordFormatter.format(record).map { col -> col.toString() })
+            it.writeRecord(
+                snowflakeRecordFormatter.format(record, columnSchema).map { col -> col.toString() }
+            )
            recordCount++
            if ((recordCount % flushLimit) == 0) {
                it.flush()


@@ -8,99 +8,62 @@ import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.data.csv.toCsvValue
-import io.airbyte.cdk.load.message.Meta
+import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
+import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
+import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT
+import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
+import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
+import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.cdk.load.util.Jsons
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils

interface SnowflakeRecordFormatter {
-    fun format(record: Map<String, AirbyteValue>): List<Any>
+    fun format(record: Map<String, AirbyteValue>, columnSchema: ColumnSchema): List<Any>
}

-class SnowflakeSchemaRecordFormatter(
-    private val columns: LinkedHashMap<String, String>,
-    val snowflakeColumnUtils: SnowflakeColumnUtils,
-) : SnowflakeRecordFormatter {
-    private val airbyteColumnNames =
-        snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet()
-
-    override fun format(record: Map<String, AirbyteValue>): List<Any> =
-        columns.map { (columnName, _) ->
-            /*
-             * Meta columns are forced to uppercase for backwards compatibility with previous
-             * versions of the destination. Therefore, convert the column to lowercase so
-             * that it can match the constants, which use the lowercase version of the meta
-             * column names.
-             */
-            if (airbyteColumnNames.contains(columnName)) {
-                record[columnName.lowercase()].toCsvValue()
-            } else {
-                record.keys
-                    // The columns retrieved from Snowflake do not have any escaping applied.
-                    // Therefore, re-apply the compatible name escaping to the name of the
-                    // columns retrieved from Snowflake. The record keys should already have
-                    // been escaped by the CDK before arriving at the aggregate, so no need
-                    // to escape again here.
-                    .find { it == columnName.toSnowflakeCompatibleName() }
-                    ?.let { record[it].toCsvValue() }
-                    ?: ""
-            }
-        }
+class SnowflakeSchemaRecordFormatter : SnowflakeRecordFormatter {
+    override fun format(record: Map<String, AirbyteValue>, columnSchema: ColumnSchema): List<Any> {
+        val result = mutableListOf<Any>()
+        val userColumns = columnSchema.finalSchema.keys
+        // WARNING: MUST match the order defined in SnowflakeColumnManager#getTableColumnNames
+        //
+        // Why don't we just use that here? Well, unlike the user fields, the meta fields on the
+        // record are not munged for the destination. So we must access the values for those columns
+        // using the original lowercase meta key.
+        result.add(record[COLUMN_NAME_AB_RAW_ID].toCsvValue())
+        result.add(record[COLUMN_NAME_AB_EXTRACTED_AT].toCsvValue())
+        result.add(record[COLUMN_NAME_AB_META].toCsvValue())
+        result.add(record[COLUMN_NAME_AB_GENERATION_ID].toCsvValue())
+        // Add user columns from the final schema
+        userColumns.forEach { columnName -> result.add(record[columnName].toCsvValue()) }
+        return result
+    }
}
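In other words, a formatted CSV row is purely positional: the four Airbyte meta values come first, then the user columns in finalSchema iteration order, which must match the COPY column list. A hedged illustration with invented column names:

// If columnSchema.finalSchema.keys iterates as ["ID", "NAME"], the row layout is:
// [ record["_airbyte_raw_id"], record["_airbyte_extracted_at"],
//   record["_airbyte_meta"], record["_airbyte_generation_id"],
//   record["ID"], record["NAME"] ]
val expectedLayout = listOf(
    "_airbyte_raw_id", "_airbyte_extracted_at", "_airbyte_meta", "_airbyte_generation_id",
    "ID", "NAME",
)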
-class SnowflakeRawRecordFormatter(
-    columns: LinkedHashMap<String, String>,
-    val snowflakeColumnUtils: SnowflakeColumnUtils,
-) : SnowflakeRecordFormatter {
-    private val columns = columns.keys
-    private val airbyteColumnNames =
-        snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet()
-
-    override fun format(record: Map<String, AirbyteValue>): List<Any> =
+class SnowflakeRawRecordFormatter : SnowflakeRecordFormatter {
+    override fun format(record: Map<String, AirbyteValue>, columnSchema: ColumnSchema): List<Any> =
        toOutputRecord(record.toMutableMap())

    private fun toOutputRecord(record: MutableMap<String, AirbyteValue>): List<Any> {
        val outputRecord = mutableListOf<Any>()
-        // Copy the Airbyte metadata columns to the raw output, removing each
-        // one from the record to avoid duplicates in the "data" field
-        columns
-            .filter { airbyteColumnNames.contains(it) && it != Meta.COLUMN_NAME_DATA }
-            .forEach { column -> safeAddToOutput(column, record, outputRecord) }
+        val mutableRecord = record.toMutableMap()
+
+        // Add meta columns in order (except _airbyte_data which we handle specially)
+        outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_RAW_ID)?.toCsvValue() ?: "")
+        outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_EXTRACTED_AT)?.toCsvValue() ?: "")
+        outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_META)?.toCsvValue() ?: "")
+        outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_GENERATION_ID)?.toCsvValue() ?: "")
+        outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_LOADED_AT)?.toCsvValue() ?: "")

        // Do not output null values in the JSON raw output
-        val filteredRecord = record.filter { (_, v) -> v !is NullValue }
-        // Convert all the remaining columns in the record to a JSON document stored in the "data"
-        // column. Add it in the same position as the _airbyte_data column in the column list to
-        // ensure it is inserted into the proper column in the table.
-        insert(
-            columns.indexOf(Meta.COLUMN_NAME_DATA),
-            StringValue(Jsons.writeValueAsString(filteredRecord)).toCsvValue(),
-            outputRecord
-        )
+        val filteredRecord = mutableRecord.filter { (_, v) -> v !is NullValue }
+
+        // Convert all the remaining columns to a JSON document stored in the "data" column
+        outputRecord.add(StringValue(Jsons.writeValueAsString(filteredRecord)).toCsvValue())
        return outputRecord
    }
-
-    private fun safeAddToOutput(
-        key: String,
-        record: MutableMap<String, AirbyteValue>,
-        output: MutableList<Any>
-    ) {
-        val extractedValue = record.remove(key)
-        // Ensure that the data is inserted into the list at the same position as the column
-        insert(columns.indexOf(key), extractedValue?.toCsvValue() ?: "", output)
-    }
-
-    private fun insert(index: Int, value: Any, list: MutableList<Any>) {
-        /*
-         * Attempt to insert the value into the proper order in the list. If the index
-         * is already present in the list, use the add(index, element) method to insert it
-         * into the proper order and push everything to the right. If the index is at the
-         * end of the list, just use add(element) to insert it at the end. If the index
-         * is further beyond the end of the list, throw an exception as that should not occur.
-         */
-        if (index < list.size) list.add(index, value)
-        else if (index == list.size || index == list.size + 1) list.add(value)
-        else throw IndexOutOfBoundsException()
-    }
}
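For the legacy raw path, each row therefore ends with a single JSON cell holding every remaining non-null field. A minimal sketch of that filtering step, using plain stdlib types and invented data:

val fields = mapOf("id" to 1, "name" to "ada", "deleted_reason" to null)
val jsonPayload = fields.filterValues { it != null } // nulls are dropped before serialization
// -> {"id":1,"name":"ada"} once serialized into the trailing "data" column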


@@ -1,25 +0,0 @@
-/*
- * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.snowflake.write.transform
-
-import io.airbyte.cdk.load.command.DestinationStream
-import io.airbyte.cdk.load.dataflow.transform.ColumnNameMapper
-import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
-import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import jakarta.inject.Singleton
-
-@Singleton
-class SnowflakeColumnNameMapper(
-    private val catalogInfo: TableCatalog,
-    private val snowflakeConfiguration: SnowflakeConfiguration,
-) : ColumnNameMapper {
-    override fun getMappedColumnName(stream: DestinationStream, columnName: String): String {
-        if (snowflakeConfiguration.legacyRawTablesOnly == true) {
-            return columnName
-        } else {
-            return catalogInfo.getMappedColumnName(stream, columnName)!!
-        }
-    }
-}


@@ -7,9 +7,11 @@ package io.airbyte.integrations.destination.snowflake.component
import io.airbyte.cdk.load.component.TableOperationsFixtures
import io.airbyte.cdk.load.component.TableOperationsSuite
import io.airbyte.cdk.load.message.Meta
+import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
-import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.idTestWithCdcMapping
-import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.testMapping
+import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.idTestWithCdcMapping
+import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.testMapping
+import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeTestTableOperationsClient
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.parallel.Execution
@@ -20,6 +22,7 @@ import org.junit.jupiter.api.parallel.ExecutionMode
class SnowflakeTableOperationsTest(
    override val client: SnowflakeAirbyteClient,
    override val testClient: SnowflakeTestTableOperationsClient,
+    override val schemaFactory: TableSchemaFactory,
) : TableOperationsSuite {
    override val airbyteMetaColumnMapping = Meta.COLUMN_NAMES.associateWith { it.uppercase() }


@@ -9,13 +9,15 @@ import io.airbyte.cdk.load.component.TableSchemaEvolutionFixtures
import io.airbyte.cdk.load.component.TableSchemaEvolutionSuite
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.message.Meta
+import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.util.serializeToString
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
-import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.allTypesColumnNameMapping
-import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.allTypesTableSchema
-import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.idAndTestMapping
-import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.testMapping
+import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.allTypesColumnNameMapping
+import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.allTypesTableSchema
+import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.idAndTestMapping
+import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.testMapping
+import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeTestTableOperationsClient
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.parallel.Execution
@@ -27,6 +29,7 @@ class SnowflakeTableSchemaEvolutionTest(
    override val client: SnowflakeAirbyteClient,
    override val opsClient: SnowflakeAirbyteClient,
    override val testClient: SnowflakeTestTableOperationsClient,
+    override val schemaFactory: TableSchemaFactory,
) : TableSchemaEvolutionSuite {
    override val airbyteMetaColumnMapping = Meta.COLUMN_NAMES.associateWith { it.uppercase() }


@@ -2,7 +2,7 @@
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */
-package io.airbyte.integrations.destination.snowflake.component
+package io.airbyte.integrations.destination.snowflake.component.config

import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.component.TableOperationsFixtures
@@ -33,6 +33,8 @@ object SnowflakeComponentTestFixtures {
"TIME_NTZ" to ColumnType("TIME", true), "TIME_NTZ" to ColumnType("TIME", true),
"ARRAY" to ColumnType("ARRAY", true), "ARRAY" to ColumnType("ARRAY", true),
"OBJECT" to ColumnType("OBJECT", true), "OBJECT" to ColumnType("OBJECT", true),
"UNION" to ColumnType("VARIANT", true),
"LEGACY_UNION" to ColumnType("VARIANT", true),
"UNKNOWN" to ColumnType("VARIANT", true), "UNKNOWN" to ColumnType("VARIANT", true),
) )
) )


@@ -2,7 +2,7 @@
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */
-package io.airbyte.integrations.destination.snowflake.component
+package io.airbyte.integrations.destination.snowflake.component.config

import io.airbyte.cdk.load.component.config.TestConfigLoader.loadTestConfig
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration


@@ -2,22 +2,24 @@
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */
-package io.airbyte.integrations.destination.snowflake.component
+package io.airbyte.integrations.destination.snowflake.component.config

+import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.component.TestTableOperationsClient
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.dataflow.state.PartitionKey
import io.airbyte.cdk.load.dataflow.transform.RecordDTO
-import io.airbyte.cdk.load.table.TableName
+import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.util.Jsons
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.client.execute
import io.airbyte.integrations.destination.snowflake.dataflow.SnowflakeAggregate
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils
+import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
import io.airbyte.integrations.destination.snowflake.sql.andLog
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
+import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
import io.micronaut.context.annotation.Requires
import jakarta.inject.Singleton
import java.time.format.DateTimeFormatter
@@ -29,25 +31,40 @@ import net.snowflake.client.jdbc.SnowflakeTimestampWithTimezone
class SnowflakeTestTableOperationsClient(
    private val client: SnowflakeAirbyteClient,
    private val dataSource: DataSource,
-    private val snowflakeSqlNameUtils: SnowflakeSqlNameUtils,
-    private val snowflakeColumnUtils: SnowflakeColumnUtils,
+    private val sqlGenerator: SnowflakeDirectLoadSqlGenerator,
    private val snowflakeConfiguration: SnowflakeConfiguration,
+    private val columnManager: SnowflakeColumnManager,
+    private val snowflakeRecordFormatter: SnowflakeRecordFormatter,
) : TestTableOperationsClient {

    override suspend fun dropNamespace(namespace: String) {
        dataSource.execute(
-            "DROP SCHEMA IF EXISTS ${snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)}".andLog()
+            "DROP SCHEMA IF EXISTS ${sqlGenerator.fullyQualifiedNamespace(namespace)}".andLog()
        )
    }

    override suspend fun insertRecords(table: TableName, records: List<Map<String, AirbyteValue>>) {
+        // TODO: we should just pass a proper column schema
+        // Since we don't pass in a proper column schema, we have to recreate one here
+        // Fetch the columns and filter out the meta columns so we're just looking at user columns
+        val columnTypes =
+            client.describeTable(table).filterNot {
+                columnManager.getMetaColumnNames().contains(it.key)
+            }
+        val columnSchema =
+            io.airbyte.cdk.load.schema.model.ColumnSchema(
+                inputToFinalColumnNames = columnTypes.keys.associateWith { it },
+                finalSchema = columnTypes.mapValues { (_, _) -> ColumnType("", true) },
+                inputSchema = emptyMap() // Not needed for insert buffer
+            )
        val a =
            SnowflakeAggregate(
                SnowflakeInsertBuffer(
-                    table,
-                    client.describeTable(table),
-                    client,
-                    snowflakeConfiguration,
-                    snowflakeColumnUtils,
+                    tableName = table,
+                    snowflakeClient = client,
+                    snowflakeConfiguration = snowflakeConfiguration,
+                    columnSchema = columnSchema,
+                    columnManager = columnManager,
+                    snowflakeRecordFormatter = snowflakeRecordFormatter,
                )
            )
        records.forEach { a.accept(RecordDTO(it, PartitionKey(""), 0, 0)) }


@@ -7,11 +7,11 @@ package io.airbyte.integrations.destination.snowflake.write
import io.airbyte.cdk.load.test.util.DestinationCleaner
import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory
import io.airbyte.integrations.destination.snowflake.cdk.SnowflakeMigratingConfigurationSpecificationSupplier
-import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfigurationFactory
import io.airbyte.integrations.destination.snowflake.sql.STAGE_NAME_PREFIX
+import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
import io.airbyte.integrations.destination.snowflake.sql.quote
import java.nio.file.Files
import java.sql.Connection


@@ -12,16 +12,22 @@ import io.airbyte.cdk.load.data.ObjectValue
import io.airbyte.cdk.load.data.json.toAirbyteValue
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN
+import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator
import io.airbyte.cdk.load.test.util.DestinationDataDumper
import io.airbyte.cdk.load.test.util.OutputRecord
+import io.airbyte.cdk.load.util.UUIDGenerator
import io.airbyte.cdk.load.util.deserializeToNode
import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory
-import io.airbyte.integrations.destination.snowflake.db.SnowflakeFinalTableNameGenerator
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeTableSchemaMapper
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils
+import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
import io.airbyte.integrations.destination.snowflake.sql.sqlEscape
import java.math.BigDecimal
+import java.sql.Date
+import java.sql.Time
+import java.sql.Timestamp
import net.snowflake.client.jdbc.SnowflakeTimestampWithTimezone

private val AIRBYTE_META_COLUMNS = Meta.COLUMN_NAMES + setOf(CDC_DELETED_AT_COLUMN)
@@ -34,8 +40,14 @@ class SnowflakeDataDumper(
        stream: DestinationStream
    ): List<OutputRecord> {
        val config = configProvider(spec)
-        val sqlUtils = SnowflakeSqlNameUtils(config)
-        val snowflakeFinalTableNameGenerator = SnowflakeFinalTableNameGenerator(config)
+        val snowflakeFinalTableNameGenerator =
+            SnowflakeTableSchemaMapper(
+                config = config,
+                tempTableNameGenerator = DefaultTempTableNameGenerator(),
+            )
+        val snowflakeColumnManager = SnowflakeColumnManager(config)
+        val sqlGenerator =
+            SnowflakeDirectLoadSqlGenerator(UUIDGenerator(), config, snowflakeColumnManager)
        val dataSource =
            SnowflakeBeanFactory()
                .snowflakeDataSource(snowflakeConfiguration = config, airbyteEdition = "COMMUNITY")
@@ -46,7 +58,7 @@ class SnowflakeDataDumper(
        ds.connection.use { connection ->
            val statement = connection.createStatement()
            val tableName =
-                snowflakeFinalTableNameGenerator.getTableName(stream.mappedDescriptor)
+                snowflakeFinalTableNameGenerator.toFinalTableName(stream.mappedDescriptor)

            // First check if the table exists
            val tableExistsQuery =
@@ -69,7 +81,7 @@ class SnowflakeDataDumper(
            val resultSet =
                statement.executeQuery(
-                    "SELECT * FROM ${sqlUtils.fullyQualifiedName(tableName)}"
+                    "SELECT * FROM ${sqlGenerator.fullyQualifiedName(tableName)}"
                )

            while (resultSet.next()) {
@@ -143,10 +155,10 @@ class SnowflakeDataDumper(
    private fun convertValue(value: Any?): Any? =
        when (value) {
            is BigDecimal -> value.toBigInteger()
-            is java.sql.Date -> value.toLocalDate()
+            is Date -> value.toLocalDate()
            is SnowflakeTimestampWithTimezone -> value.toZonedDateTime()
-            is java.sql.Time -> value.toLocalTime()
-            is java.sql.Timestamp -> value.toLocalDateTime()
+            is Time -> value.toLocalTime()
+            is Timestamp -> value.toLocalDateTime()
            else -> value
        }
}


@@ -15,7 +15,7 @@ import io.airbyte.cdk.load.dataflow.transform.ValidationResult
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.test.util.ExpectedRecordMapper
import io.airbyte.cdk.load.test.util.OutputRecord
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.write.transform.SnowflakeValueCoercer
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Change


@@ -5,7 +5,7 @@
package io.airbyte.integrations.destination.snowflake.write

import io.airbyte.cdk.load.test.util.NameMapper
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName

class SnowflakeNameMapper : NameMapper {
    override fun mapFieldName(path: List<String>): List<String> =


@@ -9,13 +9,16 @@ import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.data.json.toAirbyteValue
import io.airbyte.cdk.load.message.Meta
+import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator
import io.airbyte.cdk.load.test.util.DestinationDataDumper
import io.airbyte.cdk.load.test.util.OutputRecord
+import io.airbyte.cdk.load.util.UUIDGenerator
import io.airbyte.cdk.load.util.deserializeToNode
import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory
-import io.airbyte.integrations.destination.snowflake.db.SnowflakeFinalTableNameGenerator
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeTableSchemaMapper
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils
+import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator

class SnowflakeRawDataDumper(
    private val configProvider: (ConfigurationSpecification) -> SnowflakeConfiguration
@@ -27,8 +30,18 @@ class SnowflakeRawDataDumper(
        val output = mutableListOf<OutputRecord>()
        val config = configProvider(spec)

-        val sqlUtils = SnowflakeSqlNameUtils(config)
-        val snowflakeFinalTableNameGenerator = SnowflakeFinalTableNameGenerator(config)
+        val snowflakeColumnManager = SnowflakeColumnManager(config)
+        val sqlGenerator =
+            SnowflakeDirectLoadSqlGenerator(
+                UUIDGenerator(),
+                config,
+                snowflakeColumnManager,
+            )
+        val snowflakeFinalTableNameGenerator =
+            SnowflakeTableSchemaMapper(
+                config = config,
+                tempTableNameGenerator = DefaultTempTableNameGenerator(),
+            )
        val dataSource =
            SnowflakeBeanFactory()
                .snowflakeDataSource(snowflakeConfiguration = config, airbyteEdition = "COMMUNITY")
@@ -37,11 +50,11 @@ class SnowflakeRawDataDumper(
        ds.connection.use { connection ->
            val statement = connection.createStatement()
            val tableName =
-                snowflakeFinalTableNameGenerator.getTableName(stream.mappedDescriptor)
+                snowflakeFinalTableNameGenerator.toFinalTableName(stream.mappedDescriptor)
            val resultSet =
                statement.executeQuery(
-                    "SELECT * FROM ${sqlUtils.fullyQualifiedName(tableName)}"
+                    "SELECT * FROM ${sqlGenerator.fullyQualifiedName(tableName)}"
                )

            while (resultSet.next()) {


@@ -6,7 +6,7 @@ package io.airbyte.integrations.destination.snowflake
import com.zaxxer.hikari.HikariConfig
import com.zaxxer.hikari.HikariDataSource
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.CdcDeletionMode
import io.airbyte.integrations.destination.snowflake.spec.KeyPairAuthConfiguration
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration


@@ -5,11 +5,9 @@
package io.airbyte.integrations.destination.snowflake.check

import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.airbyte.integrations.destination.snowflake.sql.DEFAULT_COLUMNS
-import io.airbyte.integrations.destination.snowflake.sql.RAW_DATA_COLUMN
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
@@ -23,47 +21,31 @@ internal class SnowflakeCheckerTest {
    @ParameterizedTest
    @ValueSource(booleans = [true, false])
    fun testSuccessfulCheck(isLegacyRawTablesOnly: Boolean) {
-        val defaultColumnsMap =
-            if (isLegacyRawTablesOnly) {
-                linkedMapOf<String, String>().also { map ->
-                    (DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
-                        map[it.columnName] = it.columnType
-                    }
-                }
-            } else {
-                linkedMapOf<String, String>().also { map ->
-                    (DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
-                        map[it.columnName.toSnowflakeCompatibleName()] = it.columnType
-                    }
-                }
-            }
-        val defaultColumns = defaultColumnsMap.keys.toMutableList()
        val snowflakeAirbyteClient: SnowflakeAirbyteClient =
-            mockk(relaxed = true) {
-                coEvery { countTable(any()) } returns 1L
-                coEvery { describeTable(any()) } returns defaultColumnsMap
-            }
+            mockk(relaxed = true) { coEvery { countTable(any()) } returns 1L }
        val testSchema = "test-schema"
        val snowflakeConfiguration: SnowflakeConfiguration = mockk {
            every { schema } returns testSchema
            every { legacyRawTablesOnly } returns isLegacyRawTablesOnly
        }
-        val snowflakeColumnUtils =
-            mockk<SnowflakeColumnUtils>(relaxUnitFun = true) {
-                every { getFormattedDefaultColumnNames(any()) } returns defaultColumns
-            }
+        val columnManager: SnowflakeColumnManager = SnowflakeColumnManager(snowflakeConfiguration)

        val checker =
            SnowflakeChecker(
                snowflakeAirbyteClient = snowflakeAirbyteClient,
                snowflakeConfiguration = snowflakeConfiguration,
-                snowflakeColumnUtils = snowflakeColumnUtils,
+                columnManager = columnManager,
            )
        checker.check()

        coVerify(exactly = 1) {
-            snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
+            if (isLegacyRawTablesOnly) {
+                snowflakeAirbyteClient.createNamespace(testSchema)
+            } else {
+                snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
+            }
        }
        coVerify(exactly = 1) { snowflakeAirbyteClient.createTable(any(), any(), any(), any()) }
        coVerify(exactly = 1) { snowflakeAirbyteClient.dropTable(any()) }
@@ -72,48 +54,32 @@ internal class SnowflakeCheckerTest {
@ParameterizedTest
@ValueSource(booleans = [true, false])
fun testUnsuccessfulCheck(isLegacyRawTablesOnly: Boolean) {
-    val defaultColumnsMap =
-        if (isLegacyRawTablesOnly) {
-            linkedMapOf<String, String>().also { map ->
-                (DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
-                    map[it.columnName] = it.columnType
-                }
-            }
-        } else {
-            linkedMapOf<String, String>().also { map ->
-                (DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
-                    map[it.columnName.toSnowflakeCompatibleName()] = it.columnType
-                }
-            }
-        }
-    val defaultColumns = defaultColumnsMap.keys.toMutableList()
    val snowflakeAirbyteClient: SnowflakeAirbyteClient =
-        mockk(relaxed = true) {
-            coEvery { countTable(any()) } returns 0L
-            coEvery { describeTable(any()) } returns defaultColumnsMap
-        }
+        mockk(relaxed = true) { coEvery { countTable(any()) } returns 0L }
    val testSchema = "test-schema"
    val snowflakeConfiguration: SnowflakeConfiguration = mockk {
        every { schema } returns testSchema
        every { legacyRawTablesOnly } returns isLegacyRawTablesOnly
    }
-    val snowflakeColumnUtils =
-        mockk<SnowflakeColumnUtils>(relaxUnitFun = true) {
-            every { getFormattedDefaultColumnNames(any()) } returns defaultColumns
-        }
+    val columnManager: SnowflakeColumnManager = SnowflakeColumnManager(snowflakeConfiguration)
    val checker =
        SnowflakeChecker(
            snowflakeAirbyteClient = snowflakeAirbyteClient,
            snowflakeConfiguration = snowflakeConfiguration,
-            snowflakeColumnUtils = snowflakeColumnUtils,
+            columnManager = columnManager,
        )
    assertThrows<IllegalArgumentException> { checker.check() }
    coVerify(exactly = 1) {
-        snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
+        if (isLegacyRawTablesOnly) {
+            snowflakeAirbyteClient.createNamespace(testSchema)
+        } else {
+            snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
+        }
    }
    coVerify(exactly = 1) { snowflakeAirbyteClient.createTable(any(), any(), any(), any()) }
    coVerify(exactly = 1) { snowflakeAirbyteClient.dropTable(any()) }
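Taken together, the two tests above pin down the check flow: create the (possibly normalized) namespace, create a scratch table, require exactly one row in it, and drop it regardless of outcome. A minimal sketch of that flow against a stand-in client interface — the names mirror the mocked calls, the insert step and table name are assumptions, and the real signatures differ:

// Stand-in for the handful of client calls the checker exercises (assumption).
interface CheckClientSketch {
    suspend fun createNamespace(namespace: String)
    suspend fun createTable(fullyQualifiedName: String)
    suspend fun insertCheckRecord(fullyQualifiedName: String) // hypothetical helper
    suspend fun countTable(fullyQualifiedName: String): Long
    suspend fun dropTable(fullyQualifiedName: String)
}

suspend fun runCheckSketch(
    client: CheckClientSketch,
    schema: String,
    legacyRawTablesOnly: Boolean,
) {
    // Raw mode keeps the configured schema verbatim; direct-load mode normalizes it
    // (simplified here to uppercase; the connector uses toSnowflakeCompatibleName()).
    val namespace = if (legacyRawTablesOnly) schema else schema.uppercase()
    client.createNamespace(namespace)
    val table = "$namespace.AIRBYTE_CHECK" // hypothetical scratch-table name
    try {
        client.createTable(table)
        client.insertCheckRecord(table)
        // Anything other than exactly one row fails the check; require() throws the
        // IllegalArgumentException that testUnsuccessfulCheck expects.
        require(client.countTable(table) == 1L) { "Check record was not persisted" }
    } finally {
        client.dropTable(table)
    }
}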


@@ -6,24 +6,16 @@ package io.airbyte.integrations.destination.snowflake.client
import io.airbyte.cdk.ConfigErrorException
import io.airbyte.cdk.load.command.DestinationStream
-import io.airbyte.cdk.load.command.NamespaceMapper
-import io.airbyte.cdk.load.command.Overwrite
import io.airbyte.cdk.load.component.ColumnType
-import io.airbyte.cdk.load.config.NamespaceDefinitionType
-import io.airbyte.cdk.load.data.AirbyteType
-import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
+import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.ColumnNameMapping
-import io.airbyte.cdk.load.table.TableName
-import io.airbyte.integrations.destination.snowflake.db.ColumnDefinition
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.COUNT_TOTAL_ALIAS
-import io.airbyte.integrations.destination.snowflake.sql.ColumnAndType
-import io.airbyte.integrations.destination.snowflake.sql.DEFAULT_COLUMNS
import io.airbyte.integrations.destination.snowflake.sql.QUOTE
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
import io.mockk.Runs
import io.mockk.every
@@ -49,31 +41,18 @@ internal class SnowflakeAirbyteClientTest {
private lateinit var client: SnowflakeAirbyteClient
private lateinit var dataSource: DataSource
private lateinit var sqlGenerator: SnowflakeDirectLoadSqlGenerator
-private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
private lateinit var snowflakeConfiguration: SnowflakeConfiguration
+private lateinit var columnManager: SnowflakeColumnManager

@BeforeEach
fun setup() {
    dataSource = mockk()
    sqlGenerator = mockk(relaxed = true)
-    snowflakeColumnUtils =
-        mockk(relaxed = true) {
-            every { formatColumnName(any()) } answers
-                {
-                    firstArg<String>().toSnowflakeCompatibleName()
-                }
-            every { getFormattedDefaultColumnNames(any()) } returns
-                DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() }
-        }
    snowflakeConfiguration =
        mockk(relaxed = true) { every { database } returns "test_database" }
+    columnManager = mockk(relaxed = true)
    client =
-        SnowflakeAirbyteClient(
-            dataSource,
-            sqlGenerator,
-            snowflakeColumnUtils,
-            snowflakeConfiguration
-        )
+        SnowflakeAirbyteClient(dataSource, sqlGenerator, snowflakeConfiguration, columnManager)
}

@Test
@@ -231,7 +210,7 @@ internal class SnowflakeAirbyteClientTest {
@Test
fun testCreateTable() {
    val columnNameMapping = mockk<ColumnNameMapping>(relaxed = true)
-    val stream = mockk<DestinationStream>()
+    val stream = mockk<DestinationStream>(relaxed = true)
    val tableName = TableName(namespace = "namespace", name = "name")
    val resultSet = mockk<ResultSet>(relaxed = true)
    val statement =
@@ -254,9 +233,7 @@ internal class SnowflakeAirbyteClientTest {
        columnNameMapping = columnNameMapping,
        replace = true,
    )
-    verify(exactly = 1) {
-        sqlGenerator.createTable(stream, tableName, columnNameMapping, true)
-    }
+    verify(exactly = 1) { sqlGenerator.createTable(tableName, any(), true) }
    verify(exactly = 1) { sqlGenerator.createSnowflakeStage(tableName) }
    verify(exactly = 2) { mockConnection.close() }
}
@@ -288,7 +265,7 @@ internal class SnowflakeAirbyteClientTest {
        targetTableName = destinationTableName,
    )
    verify(exactly = 1) {
-        sqlGenerator.copyTable(columnNameMapping, sourceTableName, destinationTableName)
+        sqlGenerator.copyTable(any<Set<String>>(), sourceTableName, destinationTableName)
    }
    verify(exactly = 1) { mockConnection.close() }
}
@@ -299,7 +276,7 @@ internal class SnowflakeAirbyteClientTest {
    val columnNameMapping = mockk<ColumnNameMapping>(relaxed = true)
    val sourceTableName = TableName(namespace = "namespace", name = "source")
    val destinationTableName = TableName(namespace = "namespace", name = "destination")
-    val stream = mockk<DestinationStream>()
+    val stream = mockk<DestinationStream>(relaxed = true)
    val resultSet = mockk<ResultSet>(relaxed = true)
    val statement =
        mockk<Statement> {
@@ -322,12 +299,7 @@ internal class SnowflakeAirbyteClientTest {
        targetTableName = destinationTableName,
    )
    verify(exactly = 1) {
-        sqlGenerator.upsertTable(
-            stream,
-            columnNameMapping,
-            sourceTableName,
-            destinationTableName
-        )
+        sqlGenerator.upsertTable(any(), sourceTableName, destinationTableName)
    }
    verify(exactly = 1) { mockConnection.close() }
}
@@ -379,7 +351,7 @@ internal class SnowflakeAirbyteClientTest {
    }
    every { dataSource.connection } returns mockConnection
-    every { snowflakeColumnUtils.getGenerationIdColumnName() } returns generationIdColumnName
+    every { columnManager.getGenerationIdColumnName() } returns generationIdColumnName
    every { sqlGenerator.getGenerationId(tableName) } returns
        "SELECT $generationIdColumnName FROM ${tableName.toPrettyString(QUOTE)}"
@@ -501,8 +473,8 @@ internal class SnowflakeAirbyteClientTest {
    every { dataSource.connection } returns mockConnection
    runBlocking {
-        client.copyFromStage(tableName, "test.csv.gz")
-        verify(exactly = 1) { sqlGenerator.copyFromStage(tableName, "test.csv.gz") }
+        client.copyFromStage(tableName, "test.csv.gz", listOf())
+        verify(exactly = 1) { sqlGenerator.copyFromStage(tableName, "test.csv.gz", listOf()) }
        verify(exactly = 1) { mockConnection.close() }
    }
}
@@ -556,7 +528,7 @@ internal class SnowflakeAirbyteClientTest {
"COL1" andThen "COL1" andThen
COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName() andThen COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName() andThen
"COL2" "COL2"
every { resultSet.getString("type") } returns "VARCHAR(255)" andThen "NUMBER(38,0)" every { resultSet.getString("type") } returns "VARCHAR(255)" andThen "NUMBER"
every { resultSet.getString("null?") } returns "Y" andThen "N" andThen "N" every { resultSet.getString("null?") } returns "Y" andThen "N" andThen "N"
val statement = val statement =
@@ -571,6 +543,10 @@ internal class SnowflakeAirbyteClientTest {
    every { dataSource.connection } returns connection
+    // Mock the columnManager to return the correct set of meta columns
+    every { columnManager.getMetaColumnNames() } returns
+        setOf(COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName())
    val result = client.getColumnsFromDb(tableName)
    val expectedColumns =
@@ -582,81 +558,6 @@ internal class SnowflakeAirbyteClientTest {
    assertEquals(expectedColumns, result)
}
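The mock choreography above suggests how getColumnsFromDb plausibly consumes Snowflake's DESC TABLE output: iterate the result set, skip the Airbyte meta columns supplied by the column manager, and read the "column_name", "type", and "null?" labels. A self-contained sketch of such a parse loop — an assumption, including the precision stripping; ColumnTypeSketch stands in for the CDK's ColumnType:

import java.sql.ResultSet

data class ColumnTypeSketch(val type: String, val nullable: Boolean)

// Plausible parse of DESC TABLE rows as mocked above; meta columns (e.g. the
// normalized _airbyte_raw_id) are excluded via the provided set.
fun parseDescribeResult(rs: ResultSet, metaColumns: Set<String>): Map<String, ColumnTypeSketch> {
    val columns = linkedMapOf<String, ColumnTypeSketch>()
    while (rs.next()) {
        val name = rs.getString("column_name")
        if (name in metaColumns) continue // meta columns are managed separately
        columns[name] = ColumnTypeSketch(
            type = rs.getString("type").substringBefore('('), // "VARCHAR(255)" -> "VARCHAR"
            nullable = rs.getString("null?") == "Y",
        )
    }
    return columns
}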
@Test
fun `getColumnsFromStream should return correct column definitions`() {
val schema = mockk<AirbyteType>()
val stream =
DestinationStream(
unmappedNamespace = "test_namespace",
unmappedName = "test_stream",
importType = Overwrite,
schema = schema,
generationId = 1,
minimumGenerationId = 1,
syncId = 1,
namespaceMapper = NamespaceMapper(NamespaceDefinitionType.DESTINATION)
)
val columnNameMapping =
ColumnNameMapping(
mapOf(
"col1" to "COL1_MAPPED",
"col2" to "COL2_MAPPED",
)
)
val col1FieldType = mockk<FieldType>()
every { col1FieldType.type } returns mockk()
val col2FieldType = mockk<FieldType>()
every { col2FieldType.type } returns mockk()
every { schema.asColumns() } returns
linkedMapOf("col1" to col1FieldType, "col2" to col2FieldType)
every { snowflakeColumnUtils.toDialectType(col1FieldType.type) } returns "VARCHAR(255)"
every { snowflakeColumnUtils.toDialectType(col2FieldType.type) } returns "NUMBER(38,0)"
every { snowflakeColumnUtils.columnsAndTypes(any(), any()) } returns
listOf(ColumnAndType("COL1_MAPPED", "VARCHAR"), ColumnAndType("COL2_MAPPED", "NUMBER"))
every { snowflakeColumnUtils.formatColumnName(any(), false) } answers
{
firstArg<String>().toSnowflakeCompatibleName()
}
val result = client.getColumnsFromStream(stream, columnNameMapping)
val expectedColumns =
mapOf(
"COL1_MAPPED" to ColumnType("VARCHAR", true),
"COL2_MAPPED" to ColumnType("NUMBER", true),
)
assertEquals(expectedColumns, result)
}
@Test
fun `generateSchemaChanges should correctly identify changes`() {
val columnsInDb =
setOf(
ColumnDefinition("COL1", "VARCHAR"),
ColumnDefinition("COL2", "NUMBER"),
ColumnDefinition("COL3", "BOOLEAN")
)
val columnsInStream =
setOf(
ColumnDefinition("COL1", "VARCHAR"), // Unchanged
ColumnDefinition("COL3", "TEXT"), // Modified
ColumnDefinition("COL4", "DATE") // Added
)
val (added, deleted, modified) = client.generateSchemaChanges(columnsInDb, columnsInStream)
assertEquals(1, added.size)
assertEquals("COL4", added.first().name)
assertEquals(1, deleted.size)
assertEquals("COL2", deleted.first().name)
assertEquals(1, modified.size)
assertEquals("COL3", modified.first().name)
}
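Although the test above was deleted, it documents the schema-diff contract neatly: added = in the stream but not the database, deleted = in the database but not the stream, modified = same name with a different type. A stand-alone restatement with a local stand-in for ColumnDefinition:

data class ColumnDefinitionSketch(val name: String, val type: String)

fun generateSchemaChangesSketch(
    columnsInDb: Set<ColumnDefinitionSketch>,
    columnsInStream: Set<ColumnDefinitionSketch>,
): Triple<Set<ColumnDefinitionSketch>, Set<ColumnDefinitionSketch>, Set<ColumnDefinitionSketch>> {
    val dbByName = columnsInDb.associateBy { it.name }
    val streamNames = columnsInStream.map { it.name }.toSet()
    // Added: present in the stream, absent from the database.
    val added = columnsInStream.filter { it.name !in dbByName }.toSet()
    // Deleted: present in the database, absent from the stream.
    val deleted = columnsInDb.filter { it.name !in streamNames }.toSet()
    // Modified: same name on both sides, but the declared type changed.
    val modified = columnsInStream.filter { col ->
        val dbType = dbByName[col.name]?.type
        dbType != null && dbType != col.type
    }.toSet()
    return Triple(added, deleted, modified)
}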
@Test
fun testCreateNamespaceWithNetworkFailure() {
    val namespace = "test_namespace"


@@ -4,14 +4,19 @@
package io.airbyte.integrations.destination.snowflake.dataflow
+import io.airbyte.cdk.load.command.DestinationCatalog
+import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.command.DestinationStream.Descriptor
import io.airbyte.cdk.load.dataflow.aggregate.StoreKey
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig
-import io.airbyte.cdk.load.table.TableName
+import io.airbyte.cdk.load.schema.model.TableName
+import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.write.StreamStateStore
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
+import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRawRecordFormatter
+import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
+import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertEquals
@@ -22,27 +27,33 @@ internal class SnowflakeAggregateFactoryTest {
@Test
fun testCreatingAggregateWithRawBuffer() {
    val descriptor = Descriptor(namespace = "namespace", name = "name")
-    val directLoadTableExecutionConfig =
-        DirectLoadTableExecutionConfig(
-            tableName =
-                TableName(
-                    namespace = descriptor.namespace!!,
-                    name = descriptor.name,
-                )
-        )
+    val tableName =
+        TableName(
+            namespace = descriptor.namespace!!,
+            name = descriptor.name,
+        )
+    val directLoadTableExecutionConfig = DirectLoadTableExecutionConfig(tableName = tableName)
    val key = StoreKey(namespace = descriptor.namespace!!, name = descriptor.name)
    val streamStore = StreamStateStore<DirectLoadTableExecutionConfig>()
-    streamStore.put(descriptor, directLoadTableExecutionConfig)
+    streamStore.put(key, directLoadTableExecutionConfig)
+    val stream = mockk<DestinationStream>(relaxed = true)
+    val catalog = mockk<DestinationCatalog> { every { getStream(key) } returns stream }
    val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
    val snowflakeConfiguration =
        mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns true }
-    val snowflakeColumnUtils = mockk<SnowflakeColumnUtils>(relaxed = true)
+    val columnManager = SnowflakeColumnManager(snowflakeConfiguration)
+    val snowflakeRecordFormatter: SnowflakeRecordFormatter = SnowflakeRawRecordFormatter()
    val factory =
        SnowflakeAggregateFactory(
            snowflakeClient = snowflakeClient,
            streamStateStore = streamStore,
            snowflakeConfiguration = snowflakeConfiguration,
-            snowflakeColumnUtils = snowflakeColumnUtils,
+            catalog = catalog,
+            columnManager = columnManager,
+            snowflakeRecordFormatter = snowflakeRecordFormatter,
        )
    val aggregate = factory.create(key)
    assertNotNull(aggregate)
@@ -52,26 +63,33 @@ internal class SnowflakeAggregateFactoryTest {
@Test
fun testCreatingAggregateWithStagingBuffer() {
    val descriptor = Descriptor(namespace = "namespace", name = "name")
-    val directLoadTableExecutionConfig =
-        DirectLoadTableExecutionConfig(
-            tableName =
-                TableName(
-                    namespace = descriptor.namespace!!,
-                    name = descriptor.name,
-                )
-        )
+    val tableName =
+        TableName(
+            namespace = descriptor.namespace!!,
+            name = descriptor.name,
+        )
+    val directLoadTableExecutionConfig = DirectLoadTableExecutionConfig(tableName = tableName)
    val key = StoreKey(namespace = descriptor.namespace!!, name = descriptor.name)
    val streamStore = StreamStateStore<DirectLoadTableExecutionConfig>()
-    streamStore.put(descriptor, directLoadTableExecutionConfig)
+    streamStore.put(key, directLoadTableExecutionConfig)
+    val stream = mockk<DestinationStream>(relaxed = true)
+    val catalog = mockk<DestinationCatalog> { every { getStream(key) } returns stream }
    val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
-    val snowflakeConfiguration = mockk<SnowflakeConfiguration>(relaxed = true)
-    val snowflakeColumnUtils = mockk<SnowflakeColumnUtils>(relaxed = true)
+    val snowflakeConfiguration =
+        mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns false }
+    val columnManager = SnowflakeColumnManager(snowflakeConfiguration)
+    val snowflakeRecordFormatter: SnowflakeRecordFormatter = SnowflakeSchemaRecordFormatter()
    val factory =
        SnowflakeAggregateFactory(
            snowflakeClient = snowflakeClient,
            streamStateStore = streamStore,
            snowflakeConfiguration = snowflakeConfiguration,
-            snowflakeColumnUtils = snowflakeColumnUtils,
+            catalog = catalog,
+            columnManager = columnManager,
+            snowflakeRecordFormatter = snowflakeRecordFormatter,
        )
    val aggregate = factory.create(key)
    assertNotNull(aggregate)
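The two factory tests pin the formatter choice directly to the legacyRawTablesOnly flag: raw mode gets the raw-record formatter, direct-load mode gets the schema-mapped one. A minimal sketch of that dispatch with local stand-in types:

interface RecordFormatterSketch
class RawRecordFormatterSketch : RecordFormatterSketch    // legacy raw-tables layout
class SchemaRecordFormatterSketch : RecordFormatterSketch // typed, schema-mapped layout

// Mirrors the selection the tests exercise; the real formatters take constructor
// arguments that are elided here to keep the stand-ins simple.
fun selectFormatter(legacyRawTablesOnly: Boolean): RecordFormatterSketch =
    if (legacyRawTablesOnly) RawRecordFormatterSketch() else SchemaRecordFormatterSketch()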


@@ -1,23 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.db
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
internal class SnowflakeColumnNameGeneratorTest {
@Test
fun testGetColumnName() {
val column = "test-column"
val generator =
SnowflakeColumnNameGenerator(mockk { every { legacyRawTablesOnly } returns false })
val columnName = generator.getColumnName(column)
assertEquals(column.toSnowflakeCompatibleName(), columnName.displayName)
assertEquals(column.toSnowflakeCompatibleName(), columnName.canonicalName)
}
}


@@ -1,72 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.db
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
internal class SnowflakeFinalTableNameGeneratorTest {
@Test
fun testGetTableNameWithInternalNamespace() {
val configuration =
mockk<SnowflakeConfiguration> {
every { internalTableSchema } returns "test-internal-namespace"
every { legacyRawTablesOnly } returns true
}
val generator = SnowflakeFinalTableNameGenerator(config = configuration)
val streamName = "test-stream-name"
val streamNamespace = "test-stream-namespace"
val streamDescriptor =
mockk<DestinationStream.Descriptor> {
every { namespace } returns streamNamespace
every { name } returns streamName
}
val tableName = generator.getTableName(streamDescriptor)
assertEquals("test-stream-namespace_raw__stream_test-stream-name", tableName.name)
assertEquals("test-internal-namespace", tableName.namespace)
}
@Test
fun testGetTableNameWithNamespace() {
val configuration =
mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns false }
val generator = SnowflakeFinalTableNameGenerator(config = configuration)
val streamName = "test-stream-name"
val streamNamespace = "test-stream-namespace"
val streamDescriptor =
mockk<DestinationStream.Descriptor> {
every { namespace } returns streamNamespace
every { name } returns streamName
}
val tableName = generator.getTableName(streamDescriptor)
assertEquals("TEST-STREAM-NAME", tableName.name)
assertEquals("TEST-STREAM-NAMESPACE", tableName.namespace)
}
@Test
fun testGetTableNameWithDefaultNamespace() {
val defaultNamespace = "test-default-namespace"
val configuration =
mockk<SnowflakeConfiguration> {
every { schema } returns defaultNamespace
every { legacyRawTablesOnly } returns false
}
val generator = SnowflakeFinalTableNameGenerator(config = configuration)
val streamName = "test-stream-name"
val streamDescriptor =
mockk<DestinationStream.Descriptor> {
every { namespace } returns null
every { name } returns streamName
}
val tableName = generator.getTableName(streamDescriptor)
assertEquals("TEST-STREAM-NAME", tableName.name)
assertEquals("TEST-DEFAULT-NAMESPACE", tableName.namespace)
}
}
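The deleted test still documents the naming policy: in legacy raw mode the table lands in the configured internal schema under a `<namespace>_raw__stream_<name>` name, otherwise namespace and name are normalized, with a null stream namespace falling back to the configured default schema. A self-contained sketch under those assumptions, simplifying the normalization to uppercase() (the connector uses toSnowflakeCompatibleName()) and using a local stand-in data class:

data class TableNameSketch(val namespace: String, val name: String)

fun resolveTableName(
    streamNamespace: String?,
    streamName: String,
    defaultSchema: String,
    internalTableSchema: String,
    legacyRawTablesOnly: Boolean,
): TableNameSketch =
    if (legacyRawTablesOnly) {
        // Raw mode: everything lands in the internal schema; the stream's own
        // namespace is folded into the table name (default-schema fallback is an assumption).
        TableNameSketch(
            namespace = internalTableSchema,
            name = "${streamNamespace ?: defaultSchema}_raw__stream_$streamName",
        )
    } else {
        // Direct-load mode: identifiers are normalized to Snowflake-compatible form.
        TableNameSketch(
            namespace = (streamNamespace ?: defaultSchema).uppercase(),
            name = streamName.uppercase(),
        )
    }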


@@ -1,27 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.db
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.CsvSource
internal class SnowflakeNameGeneratorsTest {
@ParameterizedTest
@CsvSource(
value =
[
"test-name,TEST-NAME",
"1-test-name,1-TEST-NAME",
"test-name!!!,TEST-NAME!!!",
"test\${name,TEST__NAME",
"test\"name,TEST\"\"NAME",
]
)
fun testToSnowflakeCompatibleName(name: String, expected: String) {
assertEquals(expected, name.toSnowflakeCompatibleName())
}
}
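The expectations above pin down the normalization rules: uppercase the identifier, double any embedded double quote, and replace `$` and `{` with underscores. A stand-alone sketch of a function satisfying exactly these cases (an assumption about the real implementation, which may handle more characters):

fun String.toSnowflakeCompatibleNameSketch(): String =
    uppercase()
        .replace("\"", "\"\"") // embedded double quotes are escaped by doubling
        .replace("$", "_")
        .replace("{", "_")

fun main() {
    // The CsvSource cases from the deleted test, replayed against the sketch.
    check("test-name".toSnowflakeCompatibleNameSketch() == "TEST-NAME")
    check("1-test-name".toSnowflakeCompatibleNameSketch() == "1-TEST-NAME")
    check("test\${name".toSnowflakeCompatibleNameSketch() == "TEST__NAME")
    check("test\"name".toSnowflakeCompatibleNameSketch() == "TEST\"\"NAME")
}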


@@ -1,358 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.sql
import com.fasterxml.jackson.databind.JsonNode
import io.airbyte.cdk.load.data.ArrayType
import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
import io.airbyte.cdk.load.data.BooleanType
import io.airbyte.cdk.load.data.DateType
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerType
import io.airbyte.cdk.load.data.NumberType
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.TimeTypeWithTimezone
import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
import io.airbyte.cdk.load.data.UnionType
import io.airbyte.cdk.load.data.UnknownType
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA
import io.airbyte.cdk.load.orchestration.db.ColumnNameGenerator
import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.integrations.destination.snowflake.db.SnowflakeColumnNameGenerator
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.mockk.every
import io.mockk.mockk
import kotlin.collections.LinkedHashMap
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.CsvSource
internal class SnowflakeColumnUtilsTest {
private lateinit var snowflakeConfiguration: SnowflakeConfiguration
private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
private lateinit var snowflakeColumnNameGenerator: SnowflakeColumnNameGenerator
@BeforeEach
fun setup() {
snowflakeConfiguration = mockk(relaxed = true)
snowflakeColumnNameGenerator =
mockk(relaxed = true) {
every { getColumnName(any()) } answers
{
val displayName =
if (snowflakeConfiguration.legacyRawTablesOnly) firstArg<String>()
else firstArg<String>().toSnowflakeCompatibleName()
val canonicalName =
if (snowflakeConfiguration.legacyRawTablesOnly) firstArg<String>()
else firstArg<String>().toSnowflakeCompatibleName()
ColumnNameGenerator.ColumnName(
displayName = displayName,
canonicalName = canonicalName,
)
}
}
snowflakeColumnUtils =
SnowflakeColumnUtils(snowflakeConfiguration, snowflakeColumnNameGenerator)
}
@Test
fun testDefaultColumns() {
val expectedDefaultColumns = DEFAULT_COLUMNS
assertEquals(expectedDefaultColumns, snowflakeColumnUtils.defaultColumns())
}
@Test
fun testDefaultRawColumns() {
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val expectedDefaultColumns = DEFAULT_COLUMNS + RAW_COLUMNS
assertEquals(expectedDefaultColumns, snowflakeColumnUtils.defaultColumns())
}
@Test
fun testGetFormattedDefaultColumnNames() {
val expectedDefaultColumnNames =
DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() }
val defaultColumnNames = snowflakeColumnUtils.getFormattedDefaultColumnNames()
assertEquals(expectedDefaultColumnNames, defaultColumnNames)
}
@Test
fun testGetFormattedDefaultColumnNamesQuoted() {
val expectedDefaultColumnNames =
DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName().quote() }
val defaultColumnNames = snowflakeColumnUtils.getFormattedDefaultColumnNames(true)
assertEquals(expectedDefaultColumnNames, defaultColumnNames)
}
@Test
fun testGetColumnName() {
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val columnNames = snowflakeColumnUtils.getColumnNames(columnNameMapping)
val expectedColumnNames =
(DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() } + listOf("actual"))
.joinToString(",") { it.quote() }
assertEquals(expectedColumnNames, columnNames)
}
@Test
fun testGetRawColumnName() {
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val columnNames = snowflakeColumnUtils.getColumnNames(columnNameMapping)
val expectedColumnNames =
(DEFAULT_COLUMNS.map { it.columnName } + RAW_COLUMNS.map { it.columnName })
.joinToString(",") { it.quote() }
assertEquals(expectedColumnNames, columnNames)
}
@Test
fun testGetRawFormattedColumnNames() {
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val schemaColumns =
mapOf(
"column_one" to FieldType(StringType, true),
"column_two" to FieldType(IntegerType, true),
"original" to FieldType(StringType, true),
CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true)
)
val expectedColumnNames =
DEFAULT_COLUMNS.map { it.columnName.quote() } +
RAW_COLUMNS.map { it.columnName.quote() }
val columnNames =
snowflakeColumnUtils.getFormattedColumnNames(
columns = schemaColumns,
columnNameMapping = columnNameMapping
)
assertEquals(expectedColumnNames.size, columnNames.size)
assertEquals(expectedColumnNames.sorted(), columnNames.sorted())
}
@Test
fun testGetFormattedColumnNames() {
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val schemaColumns =
mapOf(
"column_one" to FieldType(StringType, true),
"column_two" to FieldType(IntegerType, true),
"original" to FieldType(StringType, true),
CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true)
)
val expectedColumnNames =
listOf(
"actual",
"column_one",
"column_two",
CDC_DELETED_AT_COLUMN,
)
.map { it.quote() } +
DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName().quote() }
val columnNames =
snowflakeColumnUtils.getFormattedColumnNames(
columns = schemaColumns,
columnNameMapping = columnNameMapping
)
assertEquals(expectedColumnNames.size, columnNames.size)
assertEquals(expectedColumnNames.sorted(), columnNames.sorted())
}
@Test
fun testGetFormattedColumnNamesNoQuotes() {
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val schemaColumns =
mapOf(
"column_one" to FieldType(StringType, true),
"column_two" to FieldType(IntegerType, true),
"original" to FieldType(StringType, true),
CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true)
)
val expectedColumnNames =
listOf(
"actual",
"column_one",
"column_two",
CDC_DELETED_AT_COLUMN,
) + DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() }
val columnNames =
snowflakeColumnUtils.getFormattedColumnNames(
columns = schemaColumns,
columnNameMapping = columnNameMapping,
quote = false
)
assertEquals(expectedColumnNames.size, columnNames.size)
assertEquals(expectedColumnNames.sorted(), columnNames.sorted())
}
@Test
fun testGeneratingRawTableColumnsAndTypesNoColumnMapping() {
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val columns =
snowflakeColumnUtils.columnsAndTypes(
columns = emptyMap(),
columnNameMapping = ColumnNameMapping(emptyMap())
)
assertEquals(DEFAULT_COLUMNS.size + RAW_COLUMNS.size, columns.size)
assertEquals(
"${SnowflakeDataType.VARIANT.typeName} $NOT_NULL",
columns.find { it.columnName == RAW_DATA_COLUMN.columnName }?.columnType
)
}
@Test
fun testGeneratingColumnsAndTypesNoColumnMapping() {
val columnName = "test-column"
val fieldType = FieldType(StringType, false)
val declaredColumns = mapOf(columnName to fieldType)
val columns =
snowflakeColumnUtils.columnsAndTypes(
columns = declaredColumns,
columnNameMapping = ColumnNameMapping(emptyMap())
)
assertEquals(DEFAULT_COLUMNS.size + 1, columns.size)
assertEquals(
"${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL",
columns.find { it.columnName == columnName }?.columnType
)
}
@Test
fun testGeneratingColumnsAndTypesWithColumnMapping() {
val columnName = "test-column"
val mappedColumnName = "mapped-column-name"
val fieldType = FieldType(StringType, false)
val declaredColumns = mapOf(columnName to fieldType)
val columnNameMapping = ColumnNameMapping(mapOf(columnName to mappedColumnName))
val columns =
snowflakeColumnUtils.columnsAndTypes(
columns = declaredColumns,
columnNameMapping = columnNameMapping
)
assertEquals(DEFAULT_COLUMNS.size + 1, columns.size)
assertEquals(
"${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL",
columns.find { it.columnName == mappedColumnName }?.columnType
)
}
@Test
fun testToDialectType() {
assertEquals(
SnowflakeDataType.BOOLEAN.typeName,
snowflakeColumnUtils.toDialectType(BooleanType)
)
assertEquals(SnowflakeDataType.DATE.typeName, snowflakeColumnUtils.toDialectType(DateType))
assertEquals(
SnowflakeDataType.NUMBER.typeName,
snowflakeColumnUtils.toDialectType(IntegerType)
)
assertEquals(
SnowflakeDataType.FLOAT.typeName,
snowflakeColumnUtils.toDialectType(NumberType)
)
assertEquals(
SnowflakeDataType.VARCHAR.typeName,
snowflakeColumnUtils.toDialectType(StringType)
)
assertEquals(
SnowflakeDataType.VARCHAR.typeName,
snowflakeColumnUtils.toDialectType(TimeTypeWithTimezone)
)
assertEquals(
SnowflakeDataType.TIME.typeName,
snowflakeColumnUtils.toDialectType(TimeTypeWithoutTimezone)
)
assertEquals(
SnowflakeDataType.TIMESTAMP_TZ.typeName,
snowflakeColumnUtils.toDialectType(TimestampTypeWithTimezone)
)
assertEquals(
SnowflakeDataType.TIMESTAMP_NTZ.typeName,
snowflakeColumnUtils.toDialectType(TimestampTypeWithoutTimezone)
)
assertEquals(
SnowflakeDataType.ARRAY.typeName,
snowflakeColumnUtils.toDialectType(ArrayType(items = FieldType(StringType, false)))
)
assertEquals(
SnowflakeDataType.ARRAY.typeName,
snowflakeColumnUtils.toDialectType(ArrayTypeWithoutSchema)
)
assertEquals(
SnowflakeDataType.OBJECT.typeName,
snowflakeColumnUtils.toDialectType(
ObjectType(
properties = LinkedHashMap(),
additionalProperties = false,
)
)
)
assertEquals(
SnowflakeDataType.OBJECT.typeName,
snowflakeColumnUtils.toDialectType(ObjectTypeWithEmptySchema)
)
assertEquals(
SnowflakeDataType.OBJECT.typeName,
snowflakeColumnUtils.toDialectType(ObjectTypeWithoutSchema)
)
assertEquals(
SnowflakeDataType.VARIANT.typeName,
snowflakeColumnUtils.toDialectType(
UnionType(
options = setOf(StringType),
isLegacyUnion = true,
)
)
)
assertEquals(
SnowflakeDataType.VARIANT.typeName,
snowflakeColumnUtils.toDialectType(
UnionType(
options = emptySet(),
isLegacyUnion = false,
)
)
)
assertEquals(
SnowflakeDataType.VARIANT.typeName,
snowflakeColumnUtils.toDialectType(UnknownType(schema = mockk<JsonNode>()))
)
}
@ParameterizedTest
@CsvSource(
value =
[
"$COLUMN_NAME_DATA, true, \"$COLUMN_NAME_DATA\"",
"some-other_Column, true, \"SOME-OTHER_COLUMN\"",
"$COLUMN_NAME_DATA, false, $COLUMN_NAME_DATA",
"some-other_Column, false, SOME-OTHER_COLUMN",
"$COLUMN_NAME_DATA, true, \"$COLUMN_NAME_DATA\"",
"some-other_Column, true, \"SOME-OTHER_COLUMN\"",
]
)
fun testFormatColumnName(columnName: String, quote: Boolean, expectedFormattedName: String) {
assertEquals(
expectedFormattedName,
snowflakeColumnUtils.formatColumnName(columnName, quote)
)
}
}
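The deleted testToDialectType above is effectively the type-mapping table for this connector. Restated as a self-contained sketch — the sealed hierarchy is a local stand-in for the CDK's AirbyteType, and the strings are the SnowflakeDataType names the test asserted:

sealed interface AirbyteTypeSketch
object BooleanTy : AirbyteTypeSketch
object DateTy : AirbyteTypeSketch
object IntegerTy : AirbyteTypeSketch
object NumberTy : AirbyteTypeSketch
object StringTy : AirbyteTypeSketch
object TimeWithTzTy : AirbyteTypeSketch
object TimeWithoutTzTy : AirbyteTypeSketch
object TimestampWithTzTy : AirbyteTypeSketch
object TimestampWithoutTzTy : AirbyteTypeSketch
object ArrayTy : AirbyteTypeSketch
object ObjectTy : AirbyteTypeSketch
object UnionTy : AirbyteTypeSketch
object UnknownTy : AirbyteTypeSketch

fun toDialectTypeSketch(type: AirbyteTypeSketch): String =
    when (type) {
        BooleanTy -> "BOOLEAN"
        DateTy -> "DATE"
        IntegerTy -> "NUMBER"                // integers widen to NUMBER
        NumberTy -> "FLOAT"
        StringTy, TimeWithTzTy -> "VARCHAR"  // time-with-timezone is stored as text
        TimeWithoutTzTy -> "TIME"
        TimestampWithTzTy -> "TIMESTAMP_TZ"
        TimestampWithoutTzTy -> "TIMESTAMP_NTZ"
        ArrayTy -> "ARRAY"
        ObjectTy -> "OBJECT"
        UnionTy, UnknownTy -> "VARIANT"      // unions and unknowns fall back to VARIANT
    }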


@@ -1,97 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.sql
import io.airbyte.cdk.load.table.TableName
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
internal class SnowflakeSqlNameUtilsTest {
private lateinit var snowflakeConfiguration: SnowflakeConfiguration
private lateinit var snowflakeSqlNameUtils: SnowflakeSqlNameUtils
@BeforeEach
fun setUp() {
snowflakeConfiguration = mockk(relaxed = true)
snowflakeSqlNameUtils =
SnowflakeSqlNameUtils(snowflakeConfiguration = snowflakeConfiguration)
}
@Test
fun testFullyQualifiedName() {
val databaseName = "test-database"
val namespace = "test-namespace"
val name = "test=name"
val tableName = TableName(namespace = namespace, name = name)
every { snowflakeConfiguration.database } returns databaseName
val expectedName =
snowflakeSqlNameUtils.combineParts(
listOf(
databaseName.toSnowflakeCompatibleName(),
tableName.namespace,
tableName.name
)
)
val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedName(tableName)
assertEquals(expectedName, fullyQualifiedName)
}
@Test
fun testFullyQualifiedNamespace() {
val databaseName = "test-database"
val namespace = "test-namespace"
every { snowflakeConfiguration.database } returns databaseName
val fullyQualifiedNamespace = snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)
assertEquals("\"TEST-DATABASE\".\"test-namespace\"", fullyQualifiedNamespace)
}
@Test
fun testFullyQualifiedStageName() {
val databaseName = "test-database"
val namespace = "test-namespace"
val name = "test=name"
val tableName = TableName(namespace = namespace, name = name)
every { snowflakeConfiguration.database } returns databaseName
val expectedName =
snowflakeSqlNameUtils.combineParts(
listOf(
databaseName.toSnowflakeCompatibleName(),
namespace,
"$STAGE_NAME_PREFIX$name"
)
)
val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName)
assertEquals(expectedName, fullyQualifiedName)
}
@Test
fun testFullyQualifiedStageNameWithEscape() {
val databaseName = "test-database"
val namespace = "test-namespace"
val name = "test=\"\"\'name"
val tableName = TableName(namespace = namespace, name = name)
every { snowflakeConfiguration.database } returns databaseName
val expectedName =
snowflakeSqlNameUtils.combineParts(
listOf(
databaseName.toSnowflakeCompatibleName(),
namespace,
"$STAGE_NAME_PREFIX${sqlEscape(name)}"
)
)
val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true)
assertEquals(expectedName, fullyQualifiedName)
}
}
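From testFullyQualifiedNamespace we can read off the qualification scheme: wrap each part in double quotes, join with dots, and normalize only the database identifier (the schema part is kept as configured). A minimal sketch under those assumptions, again simplifying the normalization to uppercase():

fun combinePartsSketch(parts: List<String>): String =
    parts.joinToString(".") { "\"$it\"" } // quote each identifier, dot-join the chain

fun fullyQualifiedNamespaceSketch(database: String, namespace: String): String =
    combinePartsSketch(listOf(database.uppercase(), namespace))

fun main() {
    // Mirrors the asserted value: "TEST-DATABASE"."test-namespace"
    check(
        fullyQualifiedNamespaceSketch("test-database", "test-namespace") ==
            "\"TEST-DATABASE\".\"test-namespace\""
    )
}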


@@ -6,20 +6,21 @@ package io.airbyte.integrations.destination.snowflake.write
import io.airbyte.cdk.SystemErrorException
import io.airbyte.cdk.load.command.Append
+import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationStream
-import io.airbyte.cdk.load.orchestration.db.DatabaseInitialStatusGatherer
-import io.airbyte.cdk.load.orchestration.db.TableNames
-import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadInitialStatus
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendStreamLoader
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendTruncateStreamLoader
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableStatus
-import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
-import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableNameInfo
-import io.airbyte.cdk.load.table.ColumnNameMapping
-import io.airbyte.cdk.load.table.TableName
+import io.airbyte.cdk.load.schema.model.ColumnSchema
+import io.airbyte.cdk.load.schema.model.StreamTableSchema
+import io.airbyte.cdk.load.schema.model.TableName
+import io.airbyte.cdk.load.schema.model.TableNames
+import io.airbyte.cdk.load.table.DatabaseInitialStatusGatherer
+import io.airbyte.cdk.load.table.TempTableNameGenerator
+import io.airbyte.cdk.load.table.directload.DirectLoadInitialStatus
+import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendStreamLoader
+import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendTruncateStreamLoader
+import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
+import io.airbyte.cdk.load.table.directload.DirectLoadTableStatus
+import io.airbyte.cdk.load.write.StreamStateStore
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
@@ -34,55 +35,93 @@ internal class SnowflakeWriterTest {
@Test
fun testSetup() {
    val tableName = TableName(namespace = "test-namespace", name = "test-name")
-    val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
-    val stream = mockk<DestinationStream>()
-    val tableInfo =
-        TableNameInfo(
-            tableNames = tableNames,
-            columnNameMapping = ColumnNameMapping(emptyMap())
-        )
-    val catalog = TableCatalog(mapOf(stream to tableInfo))
+    val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
+    val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
+    val stream =
+        mockk<DestinationStream> {
+            every { tableSchema } returns
+                StreamTableSchema(
+                    tableNames = tableNames,
+                    columnSchema =
+                        ColumnSchema(
+                            inputToFinalColumnNames = emptyMap(),
+                            finalSchema = emptyMap(),
+                            inputSchema = emptyMap()
+                        ),
+                    importType = Append
+                )
+            every { mappedDescriptor } returns
+                DestinationStream.Descriptor(
+                    namespace = tableName.namespace,
+                    name = tableName.name
+                )
+            every { importType } returns Append
+        }
+    val catalog = DestinationCatalog(listOf(stream))
    val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
    val stateGatherer =
        mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
-            coEvery { gatherInitialStatus(catalog) } returns emptyMap()
+            coEvery { gatherInitialStatus() } returns
+                mapOf(
+                    stream to
+                        DirectLoadInitialStatus(
+                            realTable = DirectLoadTableStatus(false),
+                            tempTable = null
+                        )
+                )
        }
+    val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
    val writer =
        SnowflakeWriter(
-            names = catalog,
+            catalog = catalog,
            stateGatherer = stateGatherer,
-            streamStateStore = mockk(),
+            streamStateStore = streamStateStore,
            snowflakeClient = snowflakeClient,
            tempTableNameGenerator = mockk(),
-            snowflakeConfiguration = mockk(relaxed = true),
+            snowflakeConfiguration =
+                mockk(relaxed = true) {
+                    every { internalTableSchema } returns "internal_schema"
+                },
        )
    runBlocking { writer.setup() }
    coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) }
-    coVerify(exactly = 1) { stateGatherer.gatherInitialStatus(catalog) }
+    coVerify(exactly = 1) { stateGatherer.gatherInitialStatus() }
}
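testSetup pins two behaviors of the reworked writer: each distinct target namespace is created exactly once, and the status gatherer is now pre-bound to the catalog, so gatherInitialStatus() takes no argument. A function-typed sketch of that sequence (stand-ins throughout, not the writer's real signature):

suspend fun setupSketch(
    streamNamespaces: List<String>,
    createNamespace: suspend (String) -> Unit,
    gatherInitialStatus: suspend () -> Unit,
) {
    // Deduplicate so each namespace is created exactly once, as the test verifies.
    streamNamespaces.toSet().forEach { createNamespace(it) }
    // Single catalog-wide snapshot; the gatherer already knows the catalog.
    gatherInitialStatus()
}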
@Test
fun testCreateStreamLoaderFirstGeneration() {
    val tableName = TableName(namespace = "test-namespace", name = "test-name")
-    val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
+    val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
+    val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
    val stream =
        mockk<DestinationStream> {
            every { minimumGenerationId } returns 0L
            every { generationId } returns 0L
            every { importType } returns Append
+            every { tableSchema } returns
+                StreamTableSchema(
+                    tableNames = tableNames,
+                    columnSchema =
+                        ColumnSchema(
+                            inputToFinalColumnNames = emptyMap(),
+                            finalSchema = emptyMap(),
+                            inputSchema = emptyMap()
+                        ),
+                    importType = Append
+                )
+            every { mappedDescriptor } returns
+                DestinationStream.Descriptor(
+                    namespace = tableName.namespace,
+                    name = tableName.name
+                )
        }
-    val tableInfo =
-        TableNameInfo(
-            tableNames = tableNames,
-            columnNameMapping = ColumnNameMapping(emptyMap())
-        )
-    val catalog = TableCatalog(mapOf(stream to tableInfo))
+    val catalog = DestinationCatalog(listOf(stream))
    val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
    val stateGatherer =
        mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
-            coEvery { gatherInitialStatus(catalog) } returns
+            coEvery { gatherInitialStatus() } returns
                mapOf(
                    stream to
                        DirectLoadInitialStatus(
@@ -93,14 +132,18 @@ internal class SnowflakeWriterTest {
        }
    val tempTableNameGenerator =
        mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } }
+    val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
    val writer =
        SnowflakeWriter(
-            names = catalog,
+            catalog = catalog,
            stateGatherer = stateGatherer,
-            streamStateStore = mockk(),
+            streamStateStore = streamStateStore,
            snowflakeClient = snowflakeClient,
            tempTableNameGenerator = tempTableNameGenerator,
-            snowflakeConfiguration = mockk(relaxed = true),
+            snowflakeConfiguration =
+                mockk(relaxed = true) {
+                    every { internalTableSchema } returns "internal_schema"
+                },
        )
    runBlocking {
@@ -113,23 +156,35 @@ internal class SnowflakeWriterTest {
@Test
fun testCreateStreamLoaderNotFirstGeneration() {
    val tableName = TableName(namespace = "test-namespace", name = "test-name")
-    val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
+    val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
+    val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
    val stream =
        mockk<DestinationStream> {
            every { minimumGenerationId } returns 1L
            every { generationId } returns 1L
            every { importType } returns Append
+            every { tableSchema } returns
+                StreamTableSchema(
+                    tableNames = tableNames,
+                    columnSchema =
+                        ColumnSchema(
+                            inputToFinalColumnNames = emptyMap(),
+                            finalSchema = emptyMap(),
+                            inputSchema = emptyMap()
+                        ),
+                    importType = Append
+                )
+            every { mappedDescriptor } returns
+                DestinationStream.Descriptor(
+                    namespace = tableName.namespace,
+                    name = tableName.name
+                )
        }
-    val tableInfo =
-        TableNameInfo(
-            tableNames = tableNames,
-            columnNameMapping = ColumnNameMapping(emptyMap())
-        )
-    val catalog = TableCatalog(mapOf(stream to tableInfo))
+    val catalog = DestinationCatalog(listOf(stream))
    val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
    val stateGatherer =
        mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
-            coEvery { gatherInitialStatus(catalog) } returns
+            coEvery { gatherInitialStatus() } returns
                mapOf(
                    stream to
                        DirectLoadInitialStatus(
@@ -140,14 +195,18 @@ internal class SnowflakeWriterTest {
        }
    val tempTableNameGenerator =
        mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } }
+    val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
    val writer =
        SnowflakeWriter(
-            names = catalog,
+            catalog = catalog,
            stateGatherer = stateGatherer,
-            streamStateStore = mockk(),
+            streamStateStore = streamStateStore,
            snowflakeClient = snowflakeClient,
            tempTableNameGenerator = tempTableNameGenerator,
-            snowflakeConfiguration = mockk(relaxed = true),
+            snowflakeConfiguration =
+                mockk(relaxed = true) {
+                    every { internalTableSchema } returns "internal_schema"
+                },
        )
    runBlocking {
@@ -160,22 +219,35 @@ internal class SnowflakeWriterTest {
@Test
fun testCreateStreamLoaderHybrid() {
    val tableName = TableName(namespace = "test-namespace", name = "test-name")
-    val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
+    val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
+    val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
    val stream =
        mockk<DestinationStream> {
            every { minimumGenerationId } returns 1L
            every { generationId } returns 2L
+            every { importType } returns Append
+            every { tableSchema } returns
+                StreamTableSchema(
+                    tableNames = tableNames,
+                    columnSchema =
+                        ColumnSchema(
+                            inputToFinalColumnNames = emptyMap(),
+                            finalSchema = emptyMap(),
+                            inputSchema = emptyMap()
+                        ),
+                    importType = Append
+                )
+            every { mappedDescriptor } returns
+                DestinationStream.Descriptor(
+                    namespace = tableName.namespace,
+                    name = tableName.name
+                )
        }
-    val tableInfo =
-        TableNameInfo(
-            tableNames = tableNames,
-            columnNameMapping = ColumnNameMapping(emptyMap())
-        )
-    val catalog = TableCatalog(mapOf(stream to tableInfo))
+    val catalog = DestinationCatalog(listOf(stream))
    val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
    val stateGatherer =
        mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
-            coEvery { gatherInitialStatus(catalog) } returns
+            coEvery { gatherInitialStatus() } returns
                mapOf(
                    stream to
                        DirectLoadInitialStatus(
@@ -186,14 +258,18 @@ internal class SnowflakeWriterTest {
        }
    val tempTableNameGenerator =
        mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } }
+    val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
    val writer =
        SnowflakeWriter(
-            names = catalog,
+            catalog = catalog,
            stateGatherer = stateGatherer,
-            streamStateStore = mockk(),
+            streamStateStore = streamStateStore,
            snowflakeClient = snowflakeClient,
            tempTableNameGenerator = tempTableNameGenerator,
-            snowflakeConfiguration = mockk(relaxed = true),
+            snowflakeConfiguration =
+                mockk(relaxed = true) {
+                    every { internalTableSchema } returns "internal_schema"
+                },
        )
    runBlocking {
@@ -203,169 +279,126 @@ internal class SnowflakeWriterTest {
}

@Test
-fun testSetupWithNamespaceCreationFailure() {
-    val tableName = TableName(namespace = "test-namespace", name = "test-name")
-    val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
-    val stream = mockk<DestinationStream>()
-    val tableInfo =
-        TableNameInfo(
-            tableNames = tableNames,
-            columnNameMapping = ColumnNameMapping(emptyMap())
-        )
-    val catalog = TableCatalog(mapOf(stream to tableInfo))
-    val snowflakeClient = mockk<SnowflakeAirbyteClient>()
-    val stateGatherer = mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>>()
-    val writer =
-        SnowflakeWriter(
-            names = catalog,
-            stateGatherer = stateGatherer,
-            streamStateStore = mockk(),
-            snowflakeClient = snowflakeClient,
-            tempTableNameGenerator = mockk(),
-            snowflakeConfiguration = mockk(),
-        )
-    // Simulate network failure during namespace creation
-    coEvery {
-        snowflakeClient.createNamespace(tableName.namespace.toSnowflakeCompatibleName())
-    } throws RuntimeException("Network connection failed")
-    assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } }
-}
-
-@Test
-fun testSetupWithInitialStatusGatheringFailure() {
-    val tableName = TableName(namespace = "test-namespace", name = "test-name")
-    val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
-    val stream = mockk<DestinationStream>()
-    val tableInfo =
-        TableNameInfo(
-            tableNames = tableNames,
-            columnNameMapping = ColumnNameMapping(emptyMap())
-        )
-    val catalog = TableCatalog(mapOf(stream to tableInfo))
-    val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
-    val stateGatherer = mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>>()
-    val writer =
-        SnowflakeWriter(
-            names = catalog,
-            stateGatherer = stateGatherer,
-            streamStateStore = mockk(),
-            snowflakeClient = snowflakeClient,
-            tempTableNameGenerator = mockk(),
-            snowflakeConfiguration = mockk(),
-        )
-    // Simulate failure while gathering initial status
-    coEvery { stateGatherer.gatherInitialStatus(catalog) } throws
-        RuntimeException("Failed to query table status")
-    assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } }
-    // Verify namespace creation was still attempted
-    coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) }
-}
-
-@Test
-fun testCreateStreamLoaderWithMissingInitialStatus() {
-    val tableName = TableName(namespace = "test-namespace", name = "test-name")
-    val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
-    val stream =
-        mockk<DestinationStream> {
-            every { minimumGenerationId } returns 0L
-            every { generationId } returns 0L
-        }
-    val missingStream =
-        mockk<DestinationStream> {
-            every { minimumGenerationId } returns 0L
-            every { generationId } returns 0L
-        }
-    val tableInfo =
-        TableNameInfo(
-            tableNames = tableNames,
-            columnNameMapping = ColumnNameMapping(emptyMap())
-        )
-    val catalog = TableCatalog(mapOf(stream to tableInfo))
-    val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
-    val stateGatherer =
-        mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
-            coEvery { gatherInitialStatus(catalog) } returns
-                mapOf(
-                    stream to
-                        DirectLoadInitialStatus(
-                            realTable = DirectLoadTableStatus(false),
-                            tempTable = null,
-                        )
-                )
-        }
-    val writer =
-        SnowflakeWriter(
-            names = catalog,
-            stateGatherer = stateGatherer,
-            streamStateStore = mockk(),
-            snowflakeClient = snowflakeClient,
-            tempTableNameGenerator = mockk(),
-            snowflakeConfiguration = mockk(relaxed = true),
-        )
-    runBlocking {
-        writer.setup()
-        // Try to create loader for a stream that wasn't in initial status
-        assertThrows(NullPointerException::class.java) {
-            writer.createStreamLoader(missingStream)
-        }
-    }
-}
-
-@Test
-fun testCreateStreamLoaderWithNullFinalTableName() {
-    // TableNames constructor throws IllegalStateException when both names are null
-    assertThrows(IllegalStateException::class.java) {
-        TableNames(rawTableName = null, finalTableName = null)
-    }
-}
-
-@Test
-fun testSetupWithMultipleNamespaceFailuresPartial() {
-    val tableName1 = TableName(namespace = "namespace1", name = "table1")
-    val tableName2 = TableName(namespace = "namespace2", name = "table2")
-    val tableNames1 = TableNames(rawTableName = null, finalTableName = tableName1)
-    val tableNames2 = TableNames(rawTableName = null, finalTableName = tableName2)
+fun testCreateStreamLoaderNamespaceLegacy() {
+    val namespace = "test-namespace"
+    val name = "test-name"
+    val tableName = TableName(namespace = namespace, name = name)
+    val tempTableName = TableName(namespace = namespace, name = "${name}-temp")
+    val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
+    val stream =
+        mockk<DestinationStream> {
+            every { minimumGenerationId } returns 0L
+            every { generationId } returns 0L
+            every { importType } returns Append
+            every { tableSchema } returns
+                StreamTableSchema(
+                    tableNames = tableNames,
+                    columnSchema =
+                        ColumnSchema(
+                            inputToFinalColumnNames = emptyMap(),
+                            finalSchema = emptyMap(),
+                            inputSchema = emptyMap()
+                        ),
+                    importType = Append
+                )
+            every { mappedDescriptor } returns
+                DestinationStream.Descriptor(
+                    namespace = tableName.namespace,
+                    name = tableName.name
+                )
+        }
+    val catalog = DestinationCatalog(listOf(stream))
+    val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
+    val stateGatherer =
+        mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
+            coEvery { gatherInitialStatus() } returns
+                mapOf(
+                    stream to
+                        DirectLoadInitialStatus(
+                            realTable = DirectLoadTableStatus(false),
+                            tempTable = null
+                        )
+                )
+        }
+    val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
+    val writer =
+        SnowflakeWriter(
+            catalog = catalog,
+            stateGatherer = stateGatherer,
+            streamStateStore = streamStateStore,
+            snowflakeClient = snowflakeClient,
+            tempTableNameGenerator = mockk(),
+            snowflakeConfiguration =
+                mockk(relaxed = true) {
+                    every { legacyRawTablesOnly } returns true
+                    every { internalTableSchema } returns "internal_schema"
+                },
+        )
+    runBlocking { writer.setup() }
+    coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) }
+}
+
+@Test
+fun testCreateStreamLoaderNamespaceNonLegacy() {
+    val namespace = "test-namespace"
+    val name = "test-name"
+    val tableName = TableName(namespace = namespace, name = name)
+    val tempTableName = TableName(namespace = namespace, name = "${name}-temp")
+    val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
+    val stream =
+        mockk<DestinationStream> {
+            every { minimumGenerationId } returns 0L
+            every { generationId } returns 0L
+            every { importType } returns Append
+            every { tableSchema } returns
+                StreamTableSchema(
+                    tableNames = tableNames,
+                    columnSchema =
+                        ColumnSchema(
+                            inputToFinalColumnNames = emptyMap(),
+                            finalSchema = emptyMap(),
+                            inputSchema = emptyMap()
+                        ),
+                    importType = Append
+                )
+            every { mappedDescriptor } returns
+                DestinationStream.Descriptor(
+                    namespace = tableName.namespace,
+                    name = tableName.name
+                )
+        }
+    val catalog = DestinationCatalog(listOf(stream))
+    val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
+    val stateGatherer =
+        mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
+            coEvery { gatherInitialStatus() } returns
+                mapOf(
+                    stream to
+                        DirectLoadInitialStatus(
+                            realTable = DirectLoadTableStatus(false),
+                            tempTable = null
+                        )
+                )
+        }
+    val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val stream1 = mockk<DestinationStream>()
val stream2 = mockk<DestinationStream>()
val tableInfo1 =
TableNameInfo(
tableNames = tableNames1,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val tableInfo2 =
TableNameInfo(
tableNames = tableNames2,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream1 to tableInfo1, stream2 to tableInfo2))
val snowflakeClient = mockk<SnowflakeAirbyteClient>()
val stateGatherer = mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>>()
val writer = val writer =
SnowflakeWriter( SnowflakeWriter(
names = catalog, catalog = catalog,
stateGatherer = stateGatherer, stateGatherer = stateGatherer,
streamStateStore = mockk(), streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient, snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(), tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(), snowflakeConfiguration =
mockk(relaxed = true) {
every { legacyRawTablesOnly } returns false
every { internalTableSchema } returns "internal_schema"
},
) )
// First namespace succeeds, second fails (namespaces are uppercased by runBlocking { writer.setup() }
// toSnowflakeCompatibleName)
coEvery { snowflakeClient.createNamespace("namespace1") } returns Unit
coEvery { snowflakeClient.createNamespace("namespace2") } throws
RuntimeException("Connection timeout")
assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } } coVerify(exactly = 1) { snowflakeClient.createNamespace(namespace) }
// Verify both namespace creations were attempted
coVerify(exactly = 1) { snowflakeClient.createNamespace("namespace1") }
coVerify(exactly = 1) { snowflakeClient.createNamespace("namespace2") }
} }
} }
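The two new tests above differ only in the mocked `legacyRawTablesOnly` value; the stream mock, catalog, and status gatherer are otherwise identical. A shared fixture along these lines (a hypothetical helper, not part of this commit, built only from constructs already used in the tests) would remove the duplication:

    // Hypothetical test fixture; assumes the same imports as SnowflakeWriterTest.
    private fun appendStreamMock(tableName: TableName, tableNames: TableNames): DestinationStream =
        mockk {
            every { minimumGenerationId } returns 0L
            every { generationId } returns 0L
            every { importType } returns Append
            every { tableSchema } returns
                StreamTableSchema(
                    tableNames = tableNames,
                    columnSchema =
                        ColumnSchema(
                            inputToFinalColumnNames = emptyMap(),
                            finalSchema = emptyMap(),
                            inputSchema = emptyMap()
                        ),
                    importType = Append
                )
            every { mappedDescriptor } returns
                DestinationStream.Descriptor(
                    namespace = tableName.namespace,
                    name = tableName.name
                )
        }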


@@ -4,218 +4,220 @@
 package io.airbyte.integrations.destination.snowflake.write.load

+import io.airbyte.cdk.load.component.ColumnType
 import io.airbyte.cdk.load.data.AirbyteValue
+import io.airbyte.cdk.load.data.FieldType
 import io.airbyte.cdk.load.data.IntegerValue
 import io.airbyte.cdk.load.data.NullValue
+import io.airbyte.cdk.load.data.StringType
 import io.airbyte.cdk.load.data.StringValue
 import io.airbyte.cdk.load.message.Meta
-import io.airbyte.cdk.load.table.TableName
+import io.airbyte.cdk.load.schema.model.ColumnSchema
+import io.airbyte.cdk.load.schema.model.TableName
 import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
 import io.mockk.coVerify
 import io.mockk.every
 import io.mockk.mockk
 import java.io.BufferedReader
-import java.io.File
 import java.io.InputStreamReader
 import java.util.zip.GZIPInputStream
 import kotlin.io.path.exists
 import kotlinx.coroutines.runBlocking
 import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.Assertions.assertNotNull
 import org.junit.jupiter.api.BeforeEach
 import org.junit.jupiter.api.Test

 internal class SnowflakeInsertBufferTest {

     private lateinit var snowflakeConfiguration: SnowflakeConfiguration
-    private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
+    private lateinit var columnManager: SnowflakeColumnManager
+    private lateinit var columnSchema: ColumnSchema
+    private lateinit var snowflakeRecordFormatter: SnowflakeRecordFormatter

     @BeforeEach
     fun setUp() {
         snowflakeConfiguration = mockk(relaxed = true)
-        snowflakeColumnUtils = mockk(relaxed = true)
+        snowflakeRecordFormatter = SnowflakeSchemaRecordFormatter()
+        columnManager =
+            mockk(relaxed = true) {
+                every { getMetaColumns() } returns
+                    linkedMapOf(
+                        "_AIRBYTE_RAW_ID" to ColumnType("VARCHAR", false),
+                        "_AIRBYTE_EXTRACTED_AT" to ColumnType("TIMESTAMP_TZ", false),
+                        "_AIRBYTE_META" to ColumnType("VARIANT", false),
+                        "_AIRBYTE_GENERATION_ID" to ColumnType("NUMBER", true)
+                    )
+                every { getTableColumnNames(any()) } returns
+                    listOf(
+                        "_AIRBYTE_RAW_ID",
+                        "_AIRBYTE_EXTRACTED_AT",
+                        "_AIRBYTE_META",
+                        "_AIRBYTE_GENERATION_ID",
+                        "columnName"
+                    )
+            }
     }

     @Test
     fun testAccumulate() {
-        val tableName = mockk<TableName>(relaxed = true)
+        val tableName = TableName(namespace = "test", name = "table")
         val column = "columnName"
-        val columns = linkedMapOf(column to "NUMBER(38,0)")
+        columnSchema =
+            ColumnSchema(
+                inputToFinalColumnNames = mapOf(column to column.uppercase()),
+                finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
+                inputSchema = mapOf(column to FieldType(StringType, nullable = true))
+            )
         val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
         val record = createRecord(column)
         val buffer =
             SnowflakeInsertBuffer(
                 tableName = tableName,
-                columns = columns,
                 snowflakeClient = snowflakeAirbyteClient,
                 snowflakeConfiguration = snowflakeConfiguration,
-                flushLimit = 1,
-                snowflakeColumnUtils = snowflakeColumnUtils,
+                columnSchema = columnSchema,
+                columnManager = columnManager,
+                snowflakeRecordFormatter = snowflakeRecordFormatter,
             )

-        buffer.accumulate(record)
-
-        assertEquals(true, buffer.csvFilePath?.exists())
+        assertEquals(0, buffer.recordCount)
+        runBlocking { buffer.accumulate(record) }
         assertEquals(1, buffer.recordCount)
     }

     @Test
-    fun testAccumulateRaw() {
-        val tableName = mockk<TableName>(relaxed = true)
-        val column = "columnName"
-        val columns = linkedMapOf(column to "NUMBER(38,0)")
-        val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
-        val record = createRecord(column)
-        val buffer =
-            SnowflakeInsertBuffer(
-                tableName = tableName,
-                columns = columns,
-                snowflakeClient = snowflakeAirbyteClient,
-                snowflakeConfiguration = snowflakeConfiguration,
-                flushLimit = 1,
-                snowflakeColumnUtils = snowflakeColumnUtils,
-            )
-
-        every { snowflakeConfiguration.legacyRawTablesOnly } returns true
-
-        buffer.accumulate(record)
-
-        assertEquals(true, buffer.csvFilePath?.exists())
-        assertEquals(1, buffer.recordCount)
-    }
-
-    @Test
-    fun testFlush() {
-        val tableName = mockk<TableName>(relaxed = true)
+    fun testFlushToStaging() {
+        val tableName = TableName(namespace = "test", name = "table")
         val column = "columnName"
-        val columns = linkedMapOf(column to "NUMBER(38,0)")
+        columnSchema =
+            ColumnSchema(
+                inputToFinalColumnNames = mapOf(column to column.uppercase()),
+                finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
+                inputSchema = mapOf(column to FieldType(StringType, nullable = true))
+            )
         val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
         val record = createRecord(column)
         val buffer =
             SnowflakeInsertBuffer(
                 tableName = tableName,
-                columns = columns,
                 snowflakeClient = snowflakeAirbyteClient,
                 snowflakeConfiguration = snowflakeConfiguration,
+                columnSchema = columnSchema,
+                columnManager = columnManager,
+                snowflakeRecordFormatter = snowflakeRecordFormatter,
                 flushLimit = 1,
-                snowflakeColumnUtils = snowflakeColumnUtils,
             )
+        val expectedColumnNames =
+            listOf(
+                "_AIRBYTE_RAW_ID",
+                "_AIRBYTE_EXTRACTED_AT",
+                "_AIRBYTE_META",
+                "_AIRBYTE_GENERATION_ID",
+                "columnName"
+            )

         runBlocking {
             buffer.accumulate(record)
             buffer.flush()
-        }
-
-        coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
-        coVerify(exactly = 1) {
-            snowflakeAirbyteClient.copyFromStage(
-                tableName,
-                match { it.endsWith("$CSV_FILE_EXTENSION$FILE_SUFFIX") }
-            )
+            coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
+            coVerify(exactly = 1) {
+                snowflakeAirbyteClient.copyFromStage(tableName, any(), expectedColumnNames)
+            }
         }
     }

     @Test
-    fun testFlushRaw() {
-        val tableName = mockk<TableName>(relaxed = true)
+    fun testFlushToNoStaging() {
+        val tableName = TableName(namespace = "test", name = "table")
         val column = "columnName"
-        val columns = linkedMapOf(column to "NUMBER(38,0)")
+        columnSchema =
+            ColumnSchema(
+                inputToFinalColumnNames = mapOf(column to column.uppercase()),
+                finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
+                inputSchema = mapOf(column to FieldType(StringType, nullable = true))
+            )
         val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
+        every { snowflakeConfiguration.legacyRawTablesOnly } returns true
         val record = createRecord(column)
         val buffer =
             SnowflakeInsertBuffer(
                 tableName = tableName,
-                columns = columns,
                 snowflakeClient = snowflakeAirbyteClient,
                 snowflakeConfiguration = snowflakeConfiguration,
+                columnSchema = columnSchema,
+                columnManager = columnManager,
+                snowflakeRecordFormatter = snowflakeRecordFormatter,
                 flushLimit = 1,
-                snowflakeColumnUtils = snowflakeColumnUtils,
             )
-
-        every { snowflakeConfiguration.legacyRawTablesOnly } returns true
-
-        runBlocking {
-            buffer.accumulate(record)
-            buffer.flush()
-        }
-
-        coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
-        coVerify(exactly = 1) {
-            snowflakeAirbyteClient.copyFromStage(
-                tableName,
-                match { it.endsWith("$CSV_FILE_EXTENSION$FILE_SUFFIX") }
-            )
-        }
-    }
-
-    @Test
-    fun testMissingFields() {
-        val tableName = mockk<TableName>(relaxed = true)
-        val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
-        val record = createRecord("COLUMN1")
-        val buffer =
-            SnowflakeInsertBuffer(
-                tableName = tableName,
-                columns = linkedMapOf("COLUMN1" to "NUMBER(38,0)", "COLUMN2" to "NUMBER(38,0)"),
-                snowflakeClient = snowflakeAirbyteClient,
-                snowflakeConfiguration = snowflakeConfiguration,
-                flushLimit = 1,
-                snowflakeColumnUtils = snowflakeColumnUtils,
-            )
-
-        runBlocking {
-            buffer.accumulate(record)
-            buffer.csvWriter?.flush()
-            buffer.csvWriter?.close()
-            assertEquals(
-                "test-value$CSV_FIELD_SEPARATOR$CSV_LINE_DELIMITER",
-                readFromCsvFile(buffer.csvFilePath!!.toFile())
-            )
-        }
-    }
-
-    @Test
-    fun testMissingFieldsRaw() {
-        val tableName = mockk<TableName>(relaxed = true)
-        val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
-        val record = createRecord("COLUMN1")
-        val buffer =
-            SnowflakeInsertBuffer(
-                tableName = tableName,
-                columns = linkedMapOf("COLUMN1" to "NUMBER(38,0)", "COLUMN2" to "NUMBER(38,0)"),
-                snowflakeClient = snowflakeAirbyteClient,
-                snowflakeConfiguration = snowflakeConfiguration,
-                flushLimit = 1,
-                snowflakeColumnUtils = snowflakeColumnUtils,
-            )
-
-        every { snowflakeConfiguration.legacyRawTablesOnly } returns true
-
-        runBlocking {
-            buffer.accumulate(record)
-            buffer.csvWriter?.flush()
-            buffer.csvWriter?.close()
-            assertEquals(
-                "test-value$CSV_FIELD_SEPARATOR$CSV_LINE_DELIMITER",
-                readFromCsvFile(buffer.csvFilePath!!.toFile())
-            )
-        }
-    }
-
-    private fun readFromCsvFile(file: File) =
-        GZIPInputStream(file.inputStream()).use { input ->
-            val reader = BufferedReader(InputStreamReader(input))
-            reader.readText()
-        }
-
-    private fun createRecord(columnName: String) =
-        mapOf(
-            columnName to AirbyteValue.from("test-value"),
-            Meta.COLUMN_NAME_AB_EXTRACTED_AT to IntegerValue(System.currentTimeMillis()),
-            Meta.COLUMN_NAME_AB_RAW_ID to StringValue("raw-id"),
-            Meta.COLUMN_NAME_AB_GENERATION_ID to IntegerValue(1223),
-            Meta.COLUMN_NAME_AB_META to StringValue("{\"changes\":[],\"syncId\":43}"),
-            "${columnName}Null" to NullValue
-        )
+        val expectedColumnNames =
+            listOf(
+                "_AIRBYTE_RAW_ID",
+                "_AIRBYTE_EXTRACTED_AT",
+                "_AIRBYTE_META",
+                "_AIRBYTE_GENERATION_ID",
+                "columnName"
+            )
+
+        runBlocking {
+            buffer.accumulate(record)
+            buffer.flush()
+            // In legacy raw mode, it still uses staging
+            coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
+            coVerify(exactly = 1) {
+                snowflakeAirbyteClient.copyFromStage(tableName, any(), expectedColumnNames)
+            }
+        }
+    }
+
+    @Test
+    fun testFileCreation() {
+        val tableName = TableName(namespace = "test", name = "table")
+        val column = "columnName"
+        columnSchema =
+            ColumnSchema(
+                inputToFinalColumnNames = mapOf(column to column.uppercase()),
+                finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
+                inputSchema = mapOf(column to FieldType(StringType, nullable = true))
+            )
+        val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
+        val record = createRecord(column)
+        val buffer =
+            SnowflakeInsertBuffer(
+                tableName = tableName,
+                snowflakeClient = snowflakeAirbyteClient,
+                snowflakeConfiguration = snowflakeConfiguration,
+                columnSchema = columnSchema,
+                columnManager = columnManager,
+                snowflakeRecordFormatter = snowflakeRecordFormatter,
+                flushLimit = 1,
+            )
+
+        runBlocking {
+            buffer.accumulate(record)
+            // The csvFilePath is internal, we can access it for testing
+            val filepath = buffer.csvFilePath
+            assertNotNull(filepath)
+            val file = filepath!!.toFile()
+            assert(file.exists())
+            // Close the writer to ensure all data is flushed
+            buffer.csvWriter?.close()
+            val lines = mutableListOf<String>()
+            GZIPInputStream(file.inputStream()).use { gzip ->
+                BufferedReader(InputStreamReader(gzip)).use { bufferedReader ->
+                    bufferedReader.forEachLine { line -> lines.add(line) }
+                }
+            }
+            assertEquals(1, lines.size)
+            file.delete()
+        }
+    }
+
+    private fun createRecord(column: String): Map<String, AirbyteValue> {
+        return mapOf(
+            column to IntegerValue(value = 42),
+            Meta.COLUMN_NAME_AB_GENERATION_ID to NullValue,
+            Meta.COLUMN_NAME_AB_RAW_ID to StringValue("raw-id-1"),
+            Meta.COLUMN_NAME_AB_EXTRACTED_AT to IntegerValue(1234567890),
+            Meta.COLUMN_NAME_AB_META to StringValue("meta-data-foo"),
+        )
+    }
 }
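The new testFileCreation gunzips the staged CSV inline. If other tests ever need the same readback, that logic could be lifted into a small helper; a minimal sketch (not part of this commit, using only java.io and java.util.zip) could be:

    import java.io.BufferedReader
    import java.io.File
    import java.io.InputStreamReader
    import java.util.zip.GZIPInputStream

    // Reads every line of a gzip-compressed text file, mirroring the readback
    // done inline in testFileCreation above.
    fun readGzippedLines(file: File): List<String> =
        GZIPInputStream(file.inputStream()).use { gzip ->
            BufferedReader(InputStreamReader(gzip)).use { reader -> reader.readLines() }
        }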


@@ -15,12 +15,8 @@ import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
-import io.mockk.every
-import io.mockk.mockk
-import kotlin.collections.plus
+import io.airbyte.cdk.load.schema.model.ColumnSchema
 import org.junit.jupiter.api.Assertions.assertEquals
-import org.junit.jupiter.api.BeforeEach
 import org.junit.jupiter.api.Test

 private val AIRBYTE_COLUMN_TYPES_MAP =
@@ -58,28 +54,16 @@ private fun createExpected(
 internal class SnowflakeRawRecordFormatterTest {

-    private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
-
-    @BeforeEach
-    fun setup() {
-        snowflakeColumnUtils = mockk {
-            every { getFormattedDefaultColumnNames(any()) } returns
-                AIRBYTE_COLUMN_TYPES_MAP.keys.toList()
-        }
-    }
-
     @Test
     fun testFormatting() {
         val columnName = "test-column-name"
         val columnValue = "test-column-value"
         val columns = AIRBYTE_COLUMN_TYPES_MAP
         val record = createRecord(columnName = columnName, columnValue = columnValue)
-        val formatter =
-            SnowflakeRawRecordFormatter(
-                columns = AIRBYTE_COLUMN_TYPES_MAP,
-                snowflakeColumnUtils = snowflakeColumnUtils
-            )
-        val formattedValue = formatter.format(record)
+        val formatter = SnowflakeRawRecordFormatter()
+        // RawRecordFormatter doesn't use columnSchema but still needs one per interface
+        val dummyColumnSchema = ColumnSchema(emptyMap(), emptyMap(), emptyMap())
+        val formattedValue = formatter.format(record, dummyColumnSchema)
         val expectedValue =
             createExpected(
                 record = record,
@@ -93,33 +77,28 @@ internal class SnowflakeRawRecordFormatterTest {
     fun testFormattingMigratedFromPreviousVersion() {
         val columnName = "test-column-name"
         val columnValue = "test-column-value"
-        val columnsMap =
-            linkedMapOf(
-                COLUMN_NAME_AB_EXTRACTED_AT to "TIMESTAMP_TZ(9)",
-                COLUMN_NAME_AB_LOADED_AT to "TIMESTAMP_TZ(9)",
-                COLUMN_NAME_AB_META to "VARIANT",
-                COLUMN_NAME_DATA to "VARIANT",
-                COLUMN_NAME_AB_RAW_ID to "VARCHAR(16777216)",
-                COLUMN_NAME_AB_GENERATION_ID to "NUMBER(38,0)",
-            )
         val record = createRecord(columnName = columnName, columnValue = columnValue)
-        val formatter =
-            SnowflakeRawRecordFormatter(
-                columns = columnsMap,
-                snowflakeColumnUtils = snowflakeColumnUtils
-            )
-        val formattedValue = formatter.format(record)
+        val formatter = SnowflakeRawRecordFormatter()
+        // RawRecordFormatter doesn't use columnSchema but still needs one per interface
+        val dummyColumnSchema = ColumnSchema(emptyMap(), emptyMap(), emptyMap())
+        val formattedValue = formatter.format(record, dummyColumnSchema)
+
+        // The formatter outputs in a fixed order regardless of input column order:
+        // 1. AB_RAW_ID
+        // 2. AB_EXTRACTED_AT
+        // 3. AB_META
+        // 4. AB_GENERATION_ID
+        // 5. AB_LOADED_AT
+        // 6. DATA (JSON with remaining columns)
         val expectedValue =
-            createExpected(
-                record = record,
-                columns = columnsMap,
-                airbyteColumns = columnsMap.keys.toList(),
-            )
-                .toMutableList()
-        expectedValue.add(
-            columnsMap.keys.indexOf(COLUMN_NAME_DATA),
-            "{\"$columnName\":\"$columnValue\"}"
-        )
+            listOf(
+                record[COLUMN_NAME_AB_RAW_ID]!!.toCsvValue(),
+                record[COLUMN_NAME_AB_EXTRACTED_AT]!!.toCsvValue(),
+                record[COLUMN_NAME_AB_META]!!.toCsvValue(),
+                record[COLUMN_NAME_AB_GENERATION_ID]!!.toCsvValue(),
+                record[COLUMN_NAME_AB_LOADED_AT]!!.toCsvValue(),
+                "{\"$columnName\":\"$columnValue\"}"
+            )
         assertEquals(expectedValue, formattedValue)
     }
 }
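The ordering comment in the migrated-version test fully determines the raw format's CSV layout. For reference, the six cells in order (illustrative Kotlin only, reusing the Meta column constants already imported in this test):

    // Illustrative only: expected cell order for the raw (legacy) format,
    // per the ordering documented in the test above.
    val rawFormatOrder =
        listOf(
            COLUMN_NAME_AB_RAW_ID,        // 1
            COLUMN_NAME_AB_EXTRACTED_AT,  // 2
            COLUMN_NAME_AB_META,          // 3
            COLUMN_NAME_AB_GENERATION_ID, // 4
            COLUMN_NAME_AB_LOADED_AT,     // 5
            COLUMN_NAME_DATA,             // 6: JSON object holding all remaining columns
        )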


@@ -4,66 +4,58 @@
 package io.airbyte.integrations.destination.snowflake.write.load

+import io.airbyte.cdk.load.component.ColumnType
 import io.airbyte.cdk.load.data.AirbyteValue
+import io.airbyte.cdk.load.data.FieldType
 import io.airbyte.cdk.load.data.IntegerValue
 import io.airbyte.cdk.load.data.NullValue
+import io.airbyte.cdk.load.data.StringType
 import io.airbyte.cdk.load.data.StringValue
 import io.airbyte.cdk.load.data.csv.toCsvValue
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
-import io.mockk.every
-import io.mockk.mockk
-import java.util.AbstractMap
-import kotlin.collections.component1
-import kotlin.collections.component2
+import io.airbyte.cdk.load.schema.model.ColumnSchema
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
 import org.junit.jupiter.api.Assertions.assertEquals
-import org.junit.jupiter.api.BeforeEach
 import org.junit.jupiter.api.Test

-private val AIRBYTE_COLUMN_TYPES_MAP =
-    linkedMapOf(
-            COLUMN_NAME_AB_RAW_ID to "VARCHAR(16777216)",
-            COLUMN_NAME_AB_EXTRACTED_AT to "TIMESTAMP_TZ(9)",
-            COLUMN_NAME_AB_META to "VARIANT",
-            COLUMN_NAME_AB_GENERATION_ID to "NUMBER(38,0)",
-        )
-        .mapKeys { it.key.toSnowflakeCompatibleName() }
-
 internal class SnowflakeSchemaRecordFormatterTest {

-    private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
-
-    @BeforeEach
-    fun setup() {
-        snowflakeColumnUtils = mockk {
-            every { getFormattedDefaultColumnNames(any()) } returns
-                AIRBYTE_COLUMN_TYPES_MAP.keys.toList()
-        }
-    }
+    private fun createColumnSchema(userColumns: Map<String, String>): ColumnSchema {
+        val finalSchema = linkedMapOf<String, ColumnType>()
+        val inputToFinalColumnNames = mutableMapOf<String, String>()
+        val inputSchema = mutableMapOf<String, FieldType>()
+
+        // Add user columns
+        userColumns.forEach { (name, type) ->
+            val finalName = name.toSnowflakeCompatibleName()
+            finalSchema[finalName] = ColumnType(type, true)
+            inputToFinalColumnNames[name] = finalName
+            inputSchema[name] = FieldType(StringType, nullable = true)
+        }
+
+        return ColumnSchema(
+            inputToFinalColumnNames = inputToFinalColumnNames,
+            finalSchema = finalSchema,
+            inputSchema = inputSchema
+        )
+    }

     @Test
     fun testFormatting() {
         val columnName = "test-column-name"
         val columnValue = "test-column-value"
-        val columns =
-            (AIRBYTE_COLUMN_TYPES_MAP + linkedMapOf(columnName to "VARCHAR(16777216)")).mapKeys {
-                it.key.toSnowflakeCompatibleName()
-            }
+        val userColumns = mapOf(columnName to "VARCHAR(16777216)")
+        val columnSchema = createColumnSchema(userColumns)
         val record = createRecord(columnName, columnValue)
-        val formatter =
-            SnowflakeSchemaRecordFormatter(
-                columns = columns as LinkedHashMap<String, String>,
-                snowflakeColumnUtils = snowflakeColumnUtils
-            )
-        val formattedValue = formatter.format(record)
+        val formatter = SnowflakeSchemaRecordFormatter()
+        val formattedValue = formatter.format(record, columnSchema)
         val expectedValue =
             createExpected(
                 record = record,
-                columns = columns,
+                columnSchema = columnSchema,
             )
         assertEquals(expectedValue, formattedValue)
     }
@@ -72,21 +64,15 @@ internal class SnowflakeSchemaRecordFormatterTest {
     fun testFormattingVariant() {
         val columnName = "test-column-name"
         val columnValue = "{\"test\": \"test-value\"}"
-        val columns =
-            (AIRBYTE_COLUMN_TYPES_MAP + linkedMapOf(columnName to "VARIANT")).mapKeys {
-                it.key.toSnowflakeCompatibleName()
-            }
+        val userColumns = mapOf(columnName to "VARIANT")
+        val columnSchema = createColumnSchema(userColumns)
         val record = createRecord(columnName, columnValue)
-        val formatter =
-            SnowflakeSchemaRecordFormatter(
-                columns = columns as LinkedHashMap<String, String>,
-                snowflakeColumnUtils = snowflakeColumnUtils
-            )
-        val formattedValue = formatter.format(record)
+        val formatter = SnowflakeSchemaRecordFormatter()
+        val formattedValue = formatter.format(record, columnSchema)
         val expectedValue =
             createExpected(
                 record = record,
-                columns = columns,
+                columnSchema = columnSchema,
             )
         assertEquals(expectedValue, formattedValue)
     }
@@ -95,23 +81,16 @@ internal class SnowflakeSchemaRecordFormatterTest {
     fun testFormattingMissingColumn() {
         val columnName = "test-column-name"
         val columnValue = "test-column-value"
-        val columns =
-            AIRBYTE_COLUMN_TYPES_MAP +
-                linkedMapOf(
-                    columnName to "VARCHAR(16777216)",
-                    "missing-column" to "VARCHAR(16777216)"
-                )
+        val userColumns =
+            mapOf(columnName to "VARCHAR(16777216)", "missing-column" to "VARCHAR(16777216)")
+        val columnSchema = createColumnSchema(userColumns)
         val record = createRecord(columnName, columnValue)
-        val formatter =
-            SnowflakeSchemaRecordFormatter(
-                columns = columns as LinkedHashMap<String, String>,
-                snowflakeColumnUtils = snowflakeColumnUtils
-            )
-        val formattedValue = formatter.format(record)
+        val formatter = SnowflakeSchemaRecordFormatter()
+        val formattedValue = formatter.format(record, columnSchema)
         val expectedValue =
             createExpected(
                 record = record,
-                columns = columns,
+                columnSchema = columnSchema,
                 filterMissing = false,
             )
         assertEquals(expectedValue, formattedValue)
@@ -128,16 +107,37 @@ internal class SnowflakeSchemaRecordFormatterTest {
     private fun createExpected(
         record: Map<String, AirbyteValue>,
-        columns: Map<String, String>,
+        columnSchema: ColumnSchema,
         filterMissing: Boolean = true,
-    ) =
-        record.entries
-            .associate { entry -> entry.key.toSnowflakeCompatibleName() to entry.value }
-            .map { entry -> AbstractMap.SimpleEntry(entry.key, entry.value.toCsvValue()) }
-            .sortedBy { entry ->
-                if (columns.keys.indexOf(entry.key) > -1) columns.keys.indexOf(entry.key)
-                else Int.MAX_VALUE
-            }
-            .filter { (k, _) -> if (filterMissing) columns.contains(k) else true }
-            .map { it.value }
+    ): List<Any> {
+        val columns = columnSchema.finalSchema.keys.toList()
+        val result = mutableListOf<Any>()
+
+        // Add meta columns first in the expected order
+        result.add(record[COLUMN_NAME_AB_RAW_ID]?.toCsvValue() ?: "")
+        result.add(record[COLUMN_NAME_AB_EXTRACTED_AT]?.toCsvValue() ?: "")
+        result.add(record[COLUMN_NAME_AB_META]?.toCsvValue() ?: "")
+        result.add(record[COLUMN_NAME_AB_GENERATION_ID]?.toCsvValue() ?: "")
+
+        // Add user columns
+        val userColumns =
+            columns.filterNot { col ->
+                listOf(
+                        COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName(),
+                        COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName(),
+                        COLUMN_NAME_AB_META.toSnowflakeCompatibleName(),
+                        COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()
+                    )
+                    .contains(col)
+            }
+
+        userColumns.forEach { columnName ->
+            val value = record[columnName] ?: if (!filterMissing) NullValue else null
+            if (value != null || !filterMissing) {
+                result.add(value?.toCsvValue() ?: "")
+            }
+        }
+        return result
+    }
 }
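End to end, the reworked formatter API reduces to a single call. A minimal usage sketch reusing the createColumnSchema and createRecord helpers defined in this test (literal values are invented for illustration):

    // Minimal sketch: format one record through the schema formatter.
    val columnSchema = createColumnSchema(mapOf("test-column-name" to "VARCHAR(16777216)"))
    val record = createRecord("test-column-name", "test-column-value")
    val row = SnowflakeSchemaRecordFormatter().format(record, columnSchema)
    // row holds the four _AIRBYTE_* meta cells followed by the user column cells.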


@@ -1,45 +0,0 @@
-/*
- * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.snowflake.write.transform
-
-import io.airbyte.cdk.load.command.DestinationStream
-import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
-import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.mockk.every
-import io.mockk.mockk
-import org.junit.jupiter.api.Assertions.assertEquals
-import org.junit.jupiter.api.Test
-
-internal class SnowflakeColumnNameMapperTest {
-
-    @Test
-    fun testGetMappedColumnName() {
-        val columnName = "tést-column-name"
-        val expectedName = "test-column-name"
-        val stream = mockk<DestinationStream>()
-        val tableCatalog = mockk<TableCatalog>()
-        val snowflakeConfiguration = mockk<SnowflakeConfiguration>(relaxed = true)
-
-        // Configure the mock to return the expected mapped column name
-        every { tableCatalog.getMappedColumnName(stream, columnName) } returns expectedName
-
-        val mapper = SnowflakeColumnNameMapper(tableCatalog, snowflakeConfiguration)
-        val result = mapper.getMappedColumnName(stream = stream, columnName = columnName)
-
-        assertEquals(expectedName, result)
-    }
-
-    @Test
-    fun testGetMappedColumnNameRawFormat() {
-        val columnName = "tést-column-name"
-        val stream = mockk<DestinationStream>()
-        val tableCatalog = mockk<TableCatalog>()
-        val snowflakeConfiguration =
-            mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns true }
-
-        val mapper = SnowflakeColumnNameMapper(tableCatalog, snowflakeConfiguration)
-        val result = mapper.getMappedColumnName(stream = stream, columnName = columnName)
-
-        assertEquals(columnName, result)
-    }
-}


@@ -260,6 +260,7 @@ desired namespace.

 | Version | Date | Pull Request | Subject |
 |:----------------|:-----------|:--------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| 4.0.32-rc.1 | 2025-12-18 | [70903](https://github.com/airbytehq/airbyte/pull/70903) | Upgrade to CDK 0.1.91; internal refactoring |
 | 4.0.31 | 2025-12-08 | [70442](https://github.com/airbytehq/airbyte/pull/70442) | Write VARIANT values correctly when underlying Airbyte type is a `union` |
 | 4.0.30 | 2025-11-24 | [69842](https://github.com/airbytehq/airbyte/pull/69842) | Update documentation about numeric value handling |
 | 4.0.29 | 2025-11-14 | [69342](https://github.com/airbytehq/airbyte/pull/69342) | Truncate NumberValues and IntegerValues with excessive precision instead of nullifying them |