
Snowflake maps its schema once at the start. (#70903)

This commit is contained in:
Ryan Br...
2025-12-18 15:14:01 -08:00
committed by GitHub
parent 226af71657
commit 272f243e44
47 changed files with 2047 additions and 2557 deletions

View File

@@ -1,3 +1,3 @@
testExecutionConcurrency=-1
cdkVersion=0.1.82
cdkVersion=0.1.91
JunitMethodExecutionTimeout=10m

View File

@@ -6,7 +6,7 @@ data:
connectorSubtype: database
connectorType: destination
definitionId: 424892c4-daac-4491-b35d-c6688ba547ba
dockerImageTag: 4.0.31
dockerImageTag: 4.0.32-rc.1
dockerRepository: airbyte/destination-snowflake
documentationUrl: https://docs.airbyte.com/integrations/destinations/snowflake
githubIssueLabel: destination-snowflake
@@ -31,6 +31,8 @@ data:
enabled: true
releaseStage: generally_available
releases:
rolloutConfiguration:
enableProgressiveRollout: true
breakingChanges:
2.0.0:
message: Remove GCS/S3 loading method support.

View File

@@ -11,15 +11,18 @@ import io.airbyte.cdk.load.check.CheckOperationV2
import io.airbyte.cdk.load.check.DestinationCheckerV2
import io.airbyte.cdk.load.config.DataChannelMedium
import io.airbyte.cdk.load.dataflow.config.AggregatePublishingConfig
import io.airbyte.cdk.load.orchestration.db.DefaultTempTableNameGenerator
import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator
import io.airbyte.cdk.load.table.TempTableNameGenerator
import io.airbyte.cdk.output.OutputConsumer
import io.airbyte.integrations.destination.snowflake.cdk.SnowflakeMigratingConfigurationSpecificationSupplier
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.KeyPairAuthConfiguration
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfigurationFactory
import io.airbyte.integrations.destination.snowflake.spec.UsernamePasswordAuthConfiguration
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRawRecordFormatter
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter
import io.micronaut.context.annotation.Factory
import io.micronaut.context.annotation.Primary
import io.micronaut.context.annotation.Requires
@@ -204,6 +207,17 @@ class SnowflakeBeanFactory {
outputConsumer: OutputConsumer,
) = CheckOperationV2(destinationChecker, outputConsumer)
@Singleton
fun snowflakeRecordFormatter(
snowflakeConfiguration: SnowflakeConfiguration
): SnowflakeRecordFormatter {
return if (snowflakeConfiguration.legacyRawTablesOnly) {
SnowflakeRawRecordFormatter()
} else {
SnowflakeSchemaRecordFormatter()
}
}
@Singleton
fun aggregatePublishingConfig(dataChannelMedium: DataChannelMedium): AggregatePublishingConfig {
// NOT speed mode

View File

@@ -13,13 +13,17 @@ import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.cdk.load.schema.model.StreamTableSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.schema.model.TableNames
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.TableName
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter
import jakarta.inject.Singleton
import java.time.OffsetDateTime
import java.util.UUID
@@ -31,7 +35,7 @@ internal const val CHECK_COLUMN_NAME = "test_key"
class SnowflakeChecker(
private val snowflakeAirbyteClient: SnowflakeAirbyteClient,
private val snowflakeConfiguration: SnowflakeConfiguration,
private val snowflakeColumnUtils: SnowflakeColumnUtils,
private val columnManager: SnowflakeColumnManager,
) : DestinationCheckerV2 {
override fun check() {
@@ -46,11 +50,40 @@ class SnowflakeChecker(
Meta.AirbyteMetaFields.GENERATION_ID.fieldName to AirbyteValue.from(0),
CHECK_COLUMN_NAME.toSnowflakeCompatibleName() to AirbyteValue.from("test-value")
)
val outputSchema = snowflakeConfiguration.schema.toSnowflakeCompatibleName()
val outputSchema =
if (snowflakeConfiguration.legacyRawTablesOnly) {
snowflakeConfiguration.schema
} else {
snowflakeConfiguration.schema.toSnowflakeCompatibleName()
}
val tableName =
"_airbyte_connection_test_${
UUID.randomUUID().toString().replace("-".toRegex(), "")}".toSnowflakeCompatibleName()
val qualifiedTableName = TableName(namespace = outputSchema, name = tableName)
val tableSchema =
StreamTableSchema(
tableNames =
TableNames(
finalTableName = qualifiedTableName,
tempTableName = qualifiedTableName
),
columnSchema =
ColumnSchema(
inputToFinalColumnNames =
mapOf(
CHECK_COLUMN_NAME to CHECK_COLUMN_NAME.toSnowflakeCompatibleName()
),
finalSchema =
mapOf(
CHECK_COLUMN_NAME.toSnowflakeCompatibleName() to
io.airbyte.cdk.load.component.ColumnType("VARCHAR", false)
),
inputSchema =
mapOf(CHECK_COLUMN_NAME to FieldType(StringType, nullable = false))
),
importType = Append
)
val destinationStream =
DestinationStream(
unmappedNamespace = outputSchema,
@@ -63,7 +96,8 @@ class SnowflakeChecker(
generationId = 0L,
minimumGenerationId = 0L,
syncId = 0L,
namespaceMapper = NamespaceMapper()
namespaceMapper = NamespaceMapper(),
tableSchema = tableSchema
)
runBlocking {
try {
@@ -75,14 +109,14 @@ class SnowflakeChecker(
replace = true,
)
val columns = snowflakeAirbyteClient.describeTable(qualifiedTableName)
val snowflakeInsertBuffer =
SnowflakeInsertBuffer(
tableName = qualifiedTableName,
columns = columns,
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
snowflakeColumnUtils = snowflakeColumnUtils,
columnSchema = tableSchema.columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = SnowflakeSchemaRecordFormatter(),
)
snowflakeInsertBuffer.accumulate(data)

View File

@@ -13,18 +13,16 @@ import io.airbyte.cdk.load.component.TableColumns
import io.airbyte.cdk.load.component.TableOperationsClient
import io.airbyte.cdk.load.component.TableSchema
import io.airbyte.cdk.load.component.TableSchemaEvolutionClient
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.util.deserializeToNode
import io.airbyte.integrations.destination.snowflake.db.ColumnDefinition
import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.COUNT_TOTAL_ALIAS
import io.airbyte.integrations.destination.snowflake.sql.NOT_NULL
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
import io.airbyte.integrations.destination.snowflake.sql.andLog
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton
import java.sql.ResultSet
@@ -41,13 +39,10 @@ private val log = KotlinLogging.logger {}
class SnowflakeAirbyteClient(
private val dataSource: DataSource,
private val sqlGenerator: SnowflakeDirectLoadSqlGenerator,
private val snowflakeColumnUtils: SnowflakeColumnUtils,
private val snowflakeConfiguration: SnowflakeConfiguration,
private val columnManager: SnowflakeColumnManager,
) : TableOperationsClient, TableSchemaEvolutionClient {
private val airbyteColumnNames =
snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet()
override suspend fun countTable(tableName: TableName): Long? =
try {
dataSource.connection.use { connection ->
@@ -126,7 +121,7 @@ class SnowflakeAirbyteClient(
columnNameMapping: ColumnNameMapping,
replace: Boolean
) {
execute(sqlGenerator.createTable(stream, tableName, columnNameMapping, replace))
execute(sqlGenerator.createTable(tableName, stream.tableSchema, replace))
execute(sqlGenerator.createSnowflakeStage(tableName))
}
@@ -163,7 +158,15 @@ class SnowflakeAirbyteClient(
sourceTableName: TableName,
targetTableName: TableName
) {
execute(sqlGenerator.copyTable(columnNameMapping, sourceTableName, targetTableName))
// Get all column names from the mapping (both meta columns and user columns)
val columnNames = buildSet {
// Add Airbyte meta columns (using uppercase constants)
addAll(columnManager.getMetaColumnNames())
// Add user columns from mapping
addAll(columnNameMapping.values)
}
execute(sqlGenerator.copyTable(columnNames, sourceTableName, targetTableName))
}
override suspend fun upsertTable(
@@ -172,9 +175,7 @@ class SnowflakeAirbyteClient(
sourceTableName: TableName,
targetTableName: TableName
) {
execute(
sqlGenerator.upsertTable(stream, columnNameMapping, sourceTableName, targetTableName)
)
execute(sqlGenerator.upsertTable(stream.tableSchema, sourceTableName, targetTableName))
}
override suspend fun dropTable(tableName: TableName) {
@@ -206,7 +207,7 @@ class SnowflakeAirbyteClient(
stream: DestinationStream,
columnNameMapping: ColumnNameMapping
): TableSchema {
return TableSchema(getColumnsFromStream(stream, columnNameMapping))
return TableSchema(stream.tableSchema.columnSchema.finalSchema)
}
override suspend fun applyChangeset(
@@ -253,7 +254,7 @@ class SnowflakeAirbyteClient(
val columnName = escapeJsonIdentifier(rs.getString("name"))
// Filter out airbyte columns
if (airbyteColumnNames.contains(columnName)) {
if (columnManager.getMetaColumnNames().contains(columnName)) {
continue
}
val dataType = rs.getString("type").takeWhile { char -> char != '(' }
@@ -271,49 +272,6 @@ class SnowflakeAirbyteClient(
}
}
internal fun getColumnsFromStream(
stream: DestinationStream,
columnNameMapping: ColumnNameMapping
): Map<String, ColumnType> =
snowflakeColumnUtils
.columnsAndTypes(stream.schema.asColumns(), columnNameMapping)
.filter { column -> column.columnName !in airbyteColumnNames }
.associate { column ->
// columnsAndTypes returns types as either `FOO` or `FOO NOT NULL`.
// so check for that suffix.
val nullable = !column.columnType.endsWith(NOT_NULL)
val type =
column.columnType
.takeWhile { char ->
// This is to remove any precision parts of the dialect type
char != '('
}
.removeSuffix(NOT_NULL)
.trim()
column.columnName to ColumnType(type, nullable)
}
internal fun generateSchemaChanges(
columnsInDb: Set<ColumnDefinition>,
columnsInStream: Set<ColumnDefinition>
): Triple<Set<ColumnDefinition>, Set<ColumnDefinition>, Set<ColumnDefinition>> {
val addedColumns =
columnsInStream.filter { it.name !in columnsInDb.map { col -> col.name } }.toSet()
val deletedColumns =
columnsInDb.filter { it.name !in columnsInStream.map { col -> col.name } }.toSet()
val commonColumns =
columnsInStream.filter { it.name in columnsInDb.map { col -> col.name } }.toSet()
val modifiedColumns =
commonColumns
.filter {
val dbType = columnsInDb.find { column -> it.name == column.name }?.type
it.type != dbType
}
.toSet()
return Triple(addedColumns, deletedColumns, modifiedColumns)
}
override suspend fun getGenerationId(tableName: TableName): Long =
try {
dataSource.connection.use { connection ->
@@ -326,7 +284,7 @@ class SnowflakeAirbyteClient(
* format. In order to make sure these strings will match any column names
* that we have formatted in-memory, re-apply the escaping.
*/
resultSet.getLong(snowflakeColumnUtils.getGenerationIdColumnName())
resultSet.getLong(columnManager.getGenerationIdColumnName())
} else {
log.warn {
"No generation ID found for table ${tableName.toPrettyString()}, returning 0"
@@ -351,8 +309,8 @@ class SnowflakeAirbyteClient(
execute(sqlGenerator.putInStage(tableName, tempFilePath))
}
fun copyFromStage(tableName: TableName, filename: String) {
execute(sqlGenerator.copyFromStage(tableName, filename))
fun copyFromStage(tableName: TableName, filename: String, columnNames: List<String>) {
execute(sqlGenerator.copyFromStage(tableName, filename, columnNames))
}
fun describeTable(tableName: TableName): LinkedHashMap<String, String> =
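
As an aside, the column-set construction in copyTable above boils down to a small, order-preserving union. A runnable sketch with illustrative names (the meta columns assume the usual Airbyte `_airbyte_*` constants, uppercased for schema mode):

fun main() {
    // Meta columns first, then the mapped user column names; buildSet preserves
    // insertion order and deduplicates any collisions.
    val metaColumnNames =
        setOf("_AIRBYTE_RAW_ID", "_AIRBYTE_EXTRACTED_AT", "_AIRBYTE_META", "_AIRBYTE_GENERATION_ID")
    val columnNameMapping = mapOf("id" to "ID", "name" to "NAME") // input -> final
    val columnNames = buildSet {
        addAll(metaColumnNames)
        addAll(columnNameMapping.values)
    }
    println(columnNames) // [_AIRBYTE_RAW_ID, _AIRBYTE_EXTRACTED_AT, _AIRBYTE_META, _AIRBYTE_GENERATION_ID, ID, NAME]
}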

View File

@@ -4,47 +4,41 @@
package io.airbyte.integrations.destination.snowflake.dataflow
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.dataflow.aggregate.Aggregate
import io.airbyte.cdk.load.dataflow.aggregate.AggregateFactory
import io.airbyte.cdk.load.dataflow.aggregate.StoreKey
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.write.StreamStateStore
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
import io.micronaut.cache.annotation.CacheConfig
import io.micronaut.cache.annotation.Cacheable
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
import jakarta.inject.Singleton
@Singleton
@CacheConfig("table-columns")
// class has to be open to make the cache stuff work
open class SnowflakeAggregateFactory(
class SnowflakeAggregateFactory(
private val snowflakeClient: SnowflakeAirbyteClient,
private val streamStateStore: StreamStateStore<DirectLoadTableExecutionConfig>,
private val snowflakeConfiguration: SnowflakeConfiguration,
private val snowflakeColumnUtils: SnowflakeColumnUtils,
private val catalog: DestinationCatalog,
private val columnManager: SnowflakeColumnManager,
private val snowflakeRecordFormatter: SnowflakeRecordFormatter,
) : AggregateFactory {
override fun create(key: StoreKey): Aggregate {
val stream = catalog.getStream(key)
val tableName = streamStateStore.get(key)!!.tableName
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = getTableColumns(tableName),
snowflakeClient = snowflakeClient,
snowflakeConfiguration = snowflakeConfiguration,
snowflakeColumnUtils = snowflakeColumnUtils,
columnSchema = stream.tableSchema.columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
)
return SnowflakeAggregate(buffer = buffer)
}
// We assume that a table isn't getting altered _during_ a sync.
// This allows us to only SHOW COLUMNS once per table per sync,
// rather than refetching it on every aggregate.
@Cacheable
// function has to be open to make caching work
internal open fun getTableColumns(tableName: TableName) =
snowflakeClient.describeTable(tableName)
}

View File

@@ -1,13 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.db
/**
* Jdbc destination column definition representation
*
* @param name
* @param type
*/
data class ColumnDefinition(val name: String, val type: String)

View File

@@ -4,17 +4,17 @@
package io.airbyte.integrations.destination.snowflake.db
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.component.TableOperationsClient
import io.airbyte.cdk.load.orchestration.db.BaseDirectLoadInitialStatusGatherer
import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
import io.airbyte.cdk.load.table.BaseDirectLoadInitialStatusGatherer
import jakarta.inject.Singleton
@Singleton
class SnowflakeDirectLoadDatabaseInitialStatusGatherer(
tableOperationsClient: TableOperationsClient,
tempTableNameGenerator: TempTableNameGenerator,
catalog: DestinationCatalog,
) :
BaseDirectLoadInitialStatusGatherer(
tableOperationsClient,
tempTableNameGenerator,
catalog,
)

View File

@@ -1,95 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.db
import io.airbyte.cdk.ConfigErrorException
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.orchestration.db.ColumnNameGenerator
import io.airbyte.cdk.load.orchestration.db.FinalTableNameGenerator
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TypingDedupingUtil
import io.airbyte.cdk.load.table.TableName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.QUOTE
import jakarta.inject.Singleton
@Singleton
class SnowflakeFinalTableNameGenerator(private val config: SnowflakeConfiguration) :
FinalTableNameGenerator {
override fun getTableName(streamDescriptor: DestinationStream.Descriptor): TableName {
val namespace = streamDescriptor.namespace ?: config.schema
return if (!config.legacyRawTablesOnly) {
TableName(
namespace = namespace.toSnowflakeCompatibleName(),
name = streamDescriptor.name.toSnowflakeCompatibleName(),
)
} else {
TableName(
namespace = config.internalTableSchema,
name =
TypingDedupingUtil.concatenateRawTableName(
namespace = escapeJsonIdentifier(namespace),
name = escapeJsonIdentifier(streamDescriptor.name),
),
)
}
}
}
@Singleton
class SnowflakeColumnNameGenerator(private val config: SnowflakeConfiguration) :
ColumnNameGenerator {
override fun getColumnName(column: String): ColumnNameGenerator.ColumnName {
return if (!config.legacyRawTablesOnly) {
ColumnNameGenerator.ColumnName(
column.toSnowflakeCompatibleName(),
column.toSnowflakeCompatibleName(),
)
} else {
ColumnNameGenerator.ColumnName(
column,
column,
)
}
}
}
/**
* Escapes double-quotes in a JSON identifier by doubling them. This shit is legacy -- I don't know
* why this would be necessary but no harm in keeping it so I am keeping it.
*
* @return The escaped identifier.
*/
fun escapeJsonIdentifier(identifier: String): String {
// Note that we don't need to escape backslashes here!
// The only special character in an identifier is the double-quote, which needs to be
// doubled.
return identifier.replace(QUOTE, "$QUOTE$QUOTE")
}
/**
* Transforms a string to be compatible with Snowflake table and column names.
*
* @return The transformed string suitable for Snowflake identifiers.
*/
fun String.toSnowflakeCompatibleName(): String {
var identifier = this
// Handle empty strings
if (identifier.isEmpty()) {
throw ConfigErrorException("Empty string is invalid identifier")
}
// Snowflake scripting language does something weird when the `${` bigram shows up in the
// script so replace these with something else.
// For completeness, if we trigger this, also replace closing curly braces with underscores.
if (identifier.contains("\${")) {
identifier = identifier.replace("$", "_").replace("{", "_").replace("}", "_")
}
// Escape double quotes
identifier = escapeJsonIdentifier(identifier)
return identifier.uppercase()
}

View File

@@ -0,0 +1,116 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.schema
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_EXTRACTED_AT
import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_GENERATION_ID
import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_META
import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_RAW_ID
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDataType
import jakarta.inject.Singleton
/**
* Manages column names and ordering for Snowflake tables based on whether legacy raw tables mode is
* enabled.
*
* TODO: We should add meta column munging and raw table support to the CDK, so this extra layer of
* management shouldn't be necessary.
*/
@Singleton
class SnowflakeColumnManager(
private val config: SnowflakeConfiguration,
) {
/**
* Get the list of column names for a table in the order they should appear in the CSV file and
* COPY INTO statement.
*
* Warning: MUST match the order defined in SnowflakeRecordFormatter
*
* @param columnSchema The schema containing column information (ignored in raw mode)
* @return List of column names in the correct order
*/
fun getTableColumnNames(columnSchema: ColumnSchema): List<String> {
return buildList {
addAll(getMetaColumnNames())
addAll(columnSchema.finalSchema.keys)
}
}
/**
* Get the list of Airbyte meta column names. In schema mode, these are uppercase. In raw mode,
* they are lowercase and include loaded_at.
*
* @return Set of meta column names
*/
fun getMetaColumnNames(): Set<String> =
if (config.legacyRawTablesOnly) {
Constants.rawModeMetaColNames
} else {
Constants.schemaModeMetaColNames
}
/**
* Get the Airbyte meta columns as a map of column name to ColumnType. This provides both the
* column names and their types for table creation.
*
* @return Map of meta column names to their types
*/
fun getMetaColumns(): LinkedHashMap<String, ColumnType> {
return if (config.legacyRawTablesOnly) {
Constants.rawModeMetaColumns
} else {
Constants.schemaModeMetaColumns
}
}
fun getGenerationIdColumnName(): String {
return if (config.legacyRawTablesOnly) {
Meta.COLUMN_NAME_AB_GENERATION_ID
} else {
SNOWFLAKE_AB_GENERATION_ID
}
}
object Constants {
val rawModeMetaColumns =
linkedMapOf(
Meta.COLUMN_NAME_AB_RAW_ID to ColumnType(SnowflakeDataType.VARCHAR.typeName, false),
Meta.COLUMN_NAME_AB_EXTRACTED_AT to
ColumnType(
SnowflakeDataType.TIMESTAMP_TZ.typeName,
false,
),
Meta.COLUMN_NAME_AB_META to ColumnType(SnowflakeDataType.VARIANT.typeName, false),
Meta.COLUMN_NAME_AB_GENERATION_ID to
ColumnType(
SnowflakeDataType.NUMBER.typeName,
true,
),
Meta.COLUMN_NAME_AB_LOADED_AT to
ColumnType(
SnowflakeDataType.TIMESTAMP_TZ.typeName,
true,
),
)
val schemaModeMetaColumns =
linkedMapOf(
SNOWFLAKE_AB_RAW_ID to ColumnType(SnowflakeDataType.VARCHAR.typeName, false),
SNOWFLAKE_AB_EXTRACTED_AT to
ColumnType(SnowflakeDataType.TIMESTAMP_TZ.typeName, false),
SNOWFLAKE_AB_META to ColumnType(SnowflakeDataType.VARIANT.typeName, false),
SNOWFLAKE_AB_GENERATION_ID to ColumnType(SnowflakeDataType.NUMBER.typeName, true),
)
val rawModeMetaColNames: Set<String> = rawModeMetaColumns.keys
val schemaModeMetaColNames: Set<String> = schemaModeMetaColumns.keys
}
}
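
For orientation, a minimal runnable sketch of the ordering contract in getTableColumnNames, assuming the usual Airbyte `_airbyte_*` meta constants (uppercased in schema mode) and illustrative user columns:

// Mirrors getTableColumnNames: meta columns first, then finalSchema keys in map order.
val schemaModeMetaColNames =
    linkedSetOf("_AIRBYTE_RAW_ID", "_AIRBYTE_EXTRACTED_AT", "_AIRBYTE_META", "_AIRBYTE_GENERATION_ID")

fun tableColumnNames(finalSchemaKeys: Collection<String>): List<String> = buildList {
    addAll(schemaModeMetaColNames)
    addAll(finalSchemaKeys)
}

fun main() {
    // This order must line up with the per-row values SnowflakeRecordFormatter emits.
    println(tableColumnNames(listOf("ID", "NAME")))
    // [_AIRBYTE_RAW_ID, _AIRBYTE_EXTRACTED_AT, _AIRBYTE_META, _AIRBYTE_GENERATION_ID, ID, NAME]
}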

View File

@@ -0,0 +1,34 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.schema
import io.airbyte.cdk.ConfigErrorException
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
/**
* Transforms a string to be compatible with Snowflake table and column names.
*
* @return The transformed string suitable for Snowflake identifiers.
*/
fun String.toSnowflakeCompatibleName(): String {
var identifier = this
// Handle empty strings
if (identifier.isEmpty()) {
throw ConfigErrorException("Empty string is invalid identifier")
}
// Snowflake scripting language does something weird when the `${` bigram shows up in the
// script so replace these with something else.
// For completeness, if we trigger this, also replace closing curly braces with underscores.
if (identifier.contains("\${")) {
identifier = identifier.replace("$", "_").replace("{", "_").replace("}", "_")
}
// Escape double quotes
identifier = escapeJsonIdentifier(identifier)
return identifier.uppercase()
}
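
To make the transformation concrete, a standalone, runnable copy of the same logic with a few sample inputs; the expected outputs follow directly from the function body above:

fun snowflakeCompatibleName(identifier: String): String {
    require(identifier.isNotEmpty()) { "Empty string is invalid identifier" }
    // The `${` bigram triggers replacement of $, {, and } with underscores.
    var id = identifier
    if (id.contains("\${")) {
        id = id.replace("$", "_").replace("{", "_").replace("}", "_")
    }
    // Double any embedded double-quote (escapeJsonIdentifier), then uppercase.
    return id.replace("\"", "\"\"").uppercase()
}

fun main() {
    println(snowflakeCompatibleName("users"))     // USERS
    println(snowflakeCompatibleName("my\"col"))   // MY""COL
    println(snowflakeCompatibleName("cost\${x}")) // COST__X_
}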

View File

@@ -0,0 +1,121 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.schema
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.data.ArrayType
import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
import io.airbyte.cdk.load.data.BooleanType
import io.airbyte.cdk.load.data.DateType
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerType
import io.airbyte.cdk.load.data.NumberType
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.TimeTypeWithTimezone
import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
import io.airbyte.cdk.load.data.UnionType
import io.airbyte.cdk.load.data.UnknownType
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.schema.TableSchemaMapper
import io.airbyte.cdk.load.schema.model.StreamTableSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.TempTableNameGenerator
import io.airbyte.cdk.load.table.TypingDedupingUtil
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDataType
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
import jakarta.inject.Singleton
@Singleton
class SnowflakeTableSchemaMapper(
private val config: SnowflakeConfiguration,
private val tempTableNameGenerator: TempTableNameGenerator,
) : TableSchemaMapper {
override fun toFinalTableName(desc: DestinationStream.Descriptor): TableName {
val namespace = desc.namespace ?: config.schema
return if (!config.legacyRawTablesOnly) {
TableName(
namespace = namespace.toSnowflakeCompatibleName(),
name = desc.name.toSnowflakeCompatibleName(),
)
} else {
TableName(
namespace = config.internalTableSchema,
name =
TypingDedupingUtil.concatenateRawTableName(
namespace = escapeJsonIdentifier(namespace),
name = escapeJsonIdentifier(desc.name),
),
)
}
}
override fun toTempTableName(tableName: TableName): TableName {
return tempTableNameGenerator.generate(tableName)
}
override fun toColumnName(name: String): String {
return if (!config.legacyRawTablesOnly) {
name.toSnowflakeCompatibleName()
} else {
// In legacy mode, column names are not transformed
name
}
}
override fun toColumnType(fieldType: FieldType): ColumnType {
val snowflakeType =
when (fieldType.type) {
// Simple types
BooleanType -> SnowflakeDataType.BOOLEAN.typeName
IntegerType -> SnowflakeDataType.NUMBER.typeName
NumberType -> SnowflakeDataType.FLOAT.typeName
StringType -> SnowflakeDataType.VARCHAR.typeName
// Temporal types
DateType -> SnowflakeDataType.DATE.typeName
TimeTypeWithTimezone -> SnowflakeDataType.VARCHAR.typeName
TimeTypeWithoutTimezone -> SnowflakeDataType.TIME.typeName
TimestampTypeWithTimezone -> SnowflakeDataType.TIMESTAMP_TZ.typeName
TimestampTypeWithoutTimezone -> SnowflakeDataType.TIMESTAMP_NTZ.typeName
// Semistructured types
is ArrayType,
ArrayTypeWithoutSchema -> SnowflakeDataType.ARRAY.typeName
is ObjectType,
ObjectTypeWithEmptySchema,
ObjectTypeWithoutSchema -> SnowflakeDataType.OBJECT.typeName
is UnionType -> SnowflakeDataType.VARIANT.typeName
is UnknownType -> SnowflakeDataType.VARIANT.typeName
}
return ColumnType(snowflakeType, fieldType.nullable)
}
override fun toFinalSchema(tableSchema: StreamTableSchema): StreamTableSchema {
if (!config.legacyRawTablesOnly) {
return tableSchema
}
return StreamTableSchema(
tableNames = tableSchema.tableNames,
columnSchema =
tableSchema.columnSchema.copy(
finalSchema =
mapOf(
Meta.COLUMN_NAME_DATA to
ColumnType(SnowflakeDataType.OBJECT.typeName, false)
)
),
importType = tableSchema.importType,
)
}
}
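
As a worked illustration of toColumnType, a sketch using string stand-ins for the CDK's AirbyteType objects (the stand-in keys are ours; the dialect strings assume each SnowflakeDataType's typeName matches its enum name, as shown for NUMBER and FLOAT elsewhere in this commit):

fun dialectTypeFor(airbyteType: String): String =
    when (airbyteType) {
        "boolean" -> "BOOLEAN"
        "integer" -> "NUMBER"
        "number" -> "FLOAT"
        // Time-with-timezone is stored as text, matching the mapping above.
        "string", "time_with_timezone" -> "VARCHAR"
        "date" -> "DATE"
        "time_without_timezone" -> "TIME"
        "timestamp_with_timezone" -> "TIMESTAMP_TZ"
        "timestamp_without_timezone" -> "TIMESTAMP_NTZ"
        "array" -> "ARRAY"
        "object" -> "OBJECT"
        // Unions and unknown types land in VARIANT.
        else -> "VARIANT"
    }

fun main() {
    check(dialectTypeFor("integer") == "NUMBER")
    check(dialectTypeFor("union") == "VARIANT")
}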

View File

@@ -1,201 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.sql
import com.google.common.annotations.VisibleForTesting
import io.airbyte.cdk.load.data.AirbyteType
import io.airbyte.cdk.load.data.ArrayType
import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
import io.airbyte.cdk.load.data.BooleanType
import io.airbyte.cdk.load.data.DateType
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerType
import io.airbyte.cdk.load.data.NumberType
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.TimeTypeWithTimezone
import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
import io.airbyte.cdk.load.data.UnionType
import io.airbyte.cdk.load.data.UnknownType
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.integrations.destination.snowflake.db.SnowflakeColumnNameGenerator
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import jakarta.inject.Singleton
import kotlin.collections.component1
import kotlin.collections.component2
import kotlin.collections.joinToString
import kotlin.collections.map
import kotlin.collections.plus
internal const val NOT_NULL = "NOT NULL"
internal val DEFAULT_COLUMNS =
listOf(
ColumnAndType(
columnName = COLUMN_NAME_AB_RAW_ID,
columnType = "${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL"
),
ColumnAndType(
columnName = COLUMN_NAME_AB_EXTRACTED_AT,
columnType = "${SnowflakeDataType.TIMESTAMP_TZ.typeName} $NOT_NULL"
),
ColumnAndType(
columnName = COLUMN_NAME_AB_META,
columnType = "${SnowflakeDataType.VARIANT.typeName} $NOT_NULL"
),
ColumnAndType(
columnName = COLUMN_NAME_AB_GENERATION_ID,
columnType = SnowflakeDataType.NUMBER.typeName
),
)
internal val RAW_DATA_COLUMN =
ColumnAndType(
columnName = COLUMN_NAME_DATA,
columnType = "${SnowflakeDataType.VARIANT.typeName} $NOT_NULL"
)
internal val RAW_COLUMNS =
listOf(
ColumnAndType(
columnName = COLUMN_NAME_AB_LOADED_AT,
columnType = SnowflakeDataType.TIMESTAMP_TZ.typeName
),
RAW_DATA_COLUMN
)
@Singleton
class SnowflakeColumnUtils(
private val snowflakeConfiguration: SnowflakeConfiguration,
private val snowflakeColumnNameGenerator: SnowflakeColumnNameGenerator,
) {
@VisibleForTesting
internal fun defaultColumns(): List<ColumnAndType> =
if (snowflakeConfiguration.legacyRawTablesOnly) {
DEFAULT_COLUMNS + RAW_COLUMNS
} else {
DEFAULT_COLUMNS
}
internal fun formattedDefaultColumns(): List<ColumnAndType> =
defaultColumns().map {
ColumnAndType(
columnName = formatColumnName(it.columnName, false),
columnType = it.columnType,
)
}
fun getGenerationIdColumnName(): String {
return if (snowflakeConfiguration.legacyRawTablesOnly) {
COLUMN_NAME_AB_GENERATION_ID
} else {
COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()
}
}
fun getColumnNames(columnNameMapping: ColumnNameMapping): String =
if (snowflakeConfiguration.legacyRawTablesOnly) {
getFormattedDefaultColumnNames(true).joinToString(",")
} else {
(getFormattedDefaultColumnNames(true) +
columnNameMapping.map { (_, actualName) -> actualName.quote() })
.joinToString(",")
}
fun getFormattedDefaultColumnNames(quote: Boolean = false): List<String> =
defaultColumns().map { formatColumnName(it.columnName, quote) }
fun getFormattedColumnNames(
columns: Map<String, FieldType>,
columnNameMapping: ColumnNameMapping,
quote: Boolean = true,
): List<String> =
if (snowflakeConfiguration.legacyRawTablesOnly) {
getFormattedDefaultColumnNames(quote)
} else {
getFormattedDefaultColumnNames(quote) +
columns.map { (fieldName, _) ->
val columnName = columnNameMapping[fieldName] ?: fieldName
if (quote) columnName.quote() else columnName
}
}
fun columnsAndTypes(
columns: Map<String, FieldType>,
columnNameMapping: ColumnNameMapping
): List<ColumnAndType> =
if (snowflakeConfiguration.legacyRawTablesOnly) {
formattedDefaultColumns()
} else {
formattedDefaultColumns() +
columns.map { (fieldName, type) ->
val columnName = columnNameMapping[fieldName] ?: fieldName
val typeName = toDialectType(type.type)
ColumnAndType(
columnName = columnName,
columnType = if (type.nullable) typeName else "$typeName $NOT_NULL",
)
}
}
fun formatColumnName(
columnName: String,
quote: Boolean = true,
): String {
val formattedColumnName =
if (columnName == COLUMN_NAME_DATA) columnName
else snowflakeColumnNameGenerator.getColumnName(columnName).displayName
return if (quote) formattedColumnName.quote() else formattedColumnName
}
fun toDialectType(type: AirbyteType): String =
when (type) {
// Simple types
BooleanType -> SnowflakeDataType.BOOLEAN.typeName
IntegerType -> SnowflakeDataType.NUMBER.typeName
NumberType -> SnowflakeDataType.FLOAT.typeName
StringType -> SnowflakeDataType.VARCHAR.typeName
// Temporal types
DateType -> SnowflakeDataType.DATE.typeName
TimeTypeWithTimezone -> SnowflakeDataType.VARCHAR.typeName
TimeTypeWithoutTimezone -> SnowflakeDataType.TIME.typeName
TimestampTypeWithTimezone -> SnowflakeDataType.TIMESTAMP_TZ.typeName
TimestampTypeWithoutTimezone -> SnowflakeDataType.TIMESTAMP_NTZ.typeName
// Semistructured types
is ArrayType,
ArrayTypeWithoutSchema -> SnowflakeDataType.ARRAY.typeName
is ObjectType,
ObjectTypeWithEmptySchema,
ObjectTypeWithoutSchema -> SnowflakeDataType.OBJECT.typeName
is UnionType -> SnowflakeDataType.VARIANT.typeName
is UnknownType -> SnowflakeDataType.VARIANT.typeName
}
}
data class ColumnAndType(val columnName: String, val columnType: String) {
override fun toString(): String {
return "${columnName.quote()} $columnType"
}
}
/**
* Surrounds the string instance with double quotation marks (e.g. "some string" -> "\"some
* string\"").
*/
fun String.quote() = "$QUOTE$this$QUOTE"

View File

@@ -10,7 +10,7 @@ package io.airbyte.integrations.destination.snowflake.sql
*/
enum class SnowflakeDataType(val typeName: String) {
// Numeric types
NUMBER("NUMBER(38,0)"),
NUMBER("NUMBER"),
FLOAT("FLOAT"),
// String & binary types

View File

@@ -4,16 +4,19 @@
package io.airbyte.integrations.destination.snowflake.sql
import io.airbyte.cdk.load.command.Dedupe
import io.airbyte.cdk.load.command.DestinationStream
import com.google.common.annotations.VisibleForTesting
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.component.ColumnTypeChange
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.cdk.load.schema.model.StreamTableSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.util.UUIDGenerator
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.CdcDeletionMode
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.write.load.CSV_FIELD_SEPARATOR
@@ -22,6 +25,14 @@ import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton
internal const val COUNT_TOTAL_ALIAS = "TOTAL"
internal const val NOT_NULL = "NOT NULL"
// Snowflake-compatible (uppercase) versions of the Airbyte meta column names
internal val SNOWFLAKE_AB_RAW_ID = COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_EXTRACTED_AT = COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_META = COLUMN_NAME_AB_META.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_GENERATION_ID = COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN = CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()
private val log = KotlinLogging.logger {}
@@ -36,80 +47,91 @@ fun String.andLog(): String {
@Singleton
class SnowflakeDirectLoadSqlGenerator(
private val columnUtils: SnowflakeColumnUtils,
private val uuidGenerator: UUIDGenerator,
private val snowflakeConfiguration: SnowflakeConfiguration,
private val snowflakeSqlNameUtils: SnowflakeSqlNameUtils,
private val config: SnowflakeConfiguration,
private val columnManager: SnowflakeColumnManager,
) {
fun countTable(tableName: TableName): String {
return "SELECT COUNT(*) AS $COUNT_TOTAL_ALIAS FROM ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog()
return "SELECT COUNT(*) AS $COUNT_TOTAL_ALIAS FROM ${fullyQualifiedName(tableName)}".andLog()
}
fun createNamespace(namespace: String): String {
return "CREATE SCHEMA IF NOT EXISTS ${snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)}".andLog()
return "CREATE SCHEMA IF NOT EXISTS ${fullyQualifiedNamespace(namespace)}".andLog()
}
fun createTable(
stream: DestinationStream,
tableName: TableName,
columnNameMapping: ColumnNameMapping,
tableSchema: StreamTableSchema,
replace: Boolean
): String {
val finalSchema = tableSchema.columnSchema.finalSchema
val metaColumns = columnManager.getMetaColumns()
// Build column declarations from the meta columns and user schema
val columnDeclarations =
columnUtils
.columnsAndTypes(stream.schema.asColumns(), columnNameMapping)
.joinToString(",\n")
buildList {
// Add Airbyte meta columns from the column manager
metaColumns.forEach { (columnName, columnType) ->
val nullability = if (columnType.nullable) "" else " NOT NULL"
add("${columnName.quote()} ${columnType.type}$nullability")
}
// Add user columns from the munged schema
finalSchema.forEach { (columnName, columnType) ->
val nullability = if (columnType.nullable) "" else " NOT NULL"
add("${columnName.quote()} ${columnType.type}$nullability")
}
}
.joinToString(",\n ")
// Snowflake supports CREATE OR REPLACE TABLE, which is simpler than drop+recreate
val createOrReplace = if (replace) "CREATE OR REPLACE" else "CREATE"
val createTableStatement =
"""
$createOrReplace TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)} (
$columnDeclarations
)
""".trimIndent()
|$createOrReplace TABLE ${fullyQualifiedName(tableName)} (
| $columnDeclarations
|)
""".trimMargin() // Something was tripping up trimIndent so we opt for trimMargin
return createTableStatement.andLog()
}
fun showColumns(tableName: TableName): String =
"SHOW COLUMNS IN TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog()
"SHOW COLUMNS IN TABLE ${fullyQualifiedName(tableName)}".andLog()
fun copyTable(
columnNameMapping: ColumnNameMapping,
columnNames: Set<String>,
sourceTableName: TableName,
targetTableName: TableName
): String {
val columnNames = columnUtils.getColumnNames(columnNameMapping)
val columnList = columnNames.joinToString(", ") { it.quote() }
return """
INSERT INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)}
INSERT INTO ${fullyQualifiedName(targetTableName)}
(
$columnNames
$columnList
)
SELECT
$columnNames
FROM ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)}
$columnList
FROM ${fullyQualifiedName(sourceTableName)}
"""
.trimIndent()
.andLog()
}
fun upsertTable(
stream: DestinationStream,
columnNameMapping: ColumnNameMapping,
tableSchema: StreamTableSchema,
sourceTableName: TableName,
targetTableName: TableName
): String {
val importType = stream.importType as Dedupe
val finalSchema = tableSchema.columnSchema.finalSchema
// Build primary key matching condition
val pks = tableSchema.getPrimaryKey().flatten()
val pkEquivalent =
if (importType.primaryKey.isNotEmpty()) {
importType.primaryKey.joinToString(" AND ") { fieldPath ->
val fieldName = fieldPath.first()
val columnName = columnNameMapping[fieldName] ?: fieldName
if (pks.isNotEmpty()) {
pks.joinToString(" AND ") { columnName ->
val targetTableColumnName = "target_table.${columnName.quote()}"
val newRecordColumnName = "new_record.${columnName.quote()}"
"""($targetTableColumnName = $newRecordColumnName OR ($targetTableColumnName IS NULL AND $newRecordColumnName IS NULL))"""
@@ -120,80 +142,62 @@ class SnowflakeDirectLoadSqlGenerator(
}
// Build column lists for INSERT and UPDATE
val columnList: String =
columnUtils
.getFormattedColumnNames(
columns = stream.schema.asColumns(),
columnNameMapping = columnNameMapping,
quote = false,
)
.joinToString(
",\n",
) {
it.quote()
}
val allColumns = buildList {
add(SNOWFLAKE_AB_RAW_ID)
add(SNOWFLAKE_AB_EXTRACTED_AT)
add(SNOWFLAKE_AB_META)
add(SNOWFLAKE_AB_GENERATION_ID)
addAll(finalSchema.keys)
}
val columnList: String = allColumns.joinToString(",\n ") { it.quote() }
val newRecordColumnList: String =
columnUtils
.getFormattedColumnNames(
columns = stream.schema.asColumns(),
columnNameMapping = columnNameMapping,
quote = false,
)
.joinToString(",\n") { "new_record.${it.quote()}" }
allColumns.joinToString(",\n ") { "new_record.${it.quote()}" }
// Get deduped records from source
val selectSourceRecords = selectDedupedRecords(stream, sourceTableName, columnNameMapping)
val selectSourceRecords = selectDedupedRecords(tableSchema, sourceTableName)
// Build cursor comparison for determining which record is newer
val cursorComparison: String
if (importType.cursor.isNotEmpty()) {
val cursorFieldName = importType.cursor.first()
val cursor = (columnNameMapping[cursorFieldName] ?: cursorFieldName)
val cursor = tableSchema.getCursor().firstOrNull()
if (cursor != null) {
val targetTableCursor = "target_table.${cursor.quote()}"
val newRecordCursor = "new_record.${cursor.quote()}"
cursorComparison =
"""
(
$targetTableCursor < $newRecordCursor
OR ($targetTableCursor = $newRecordCursor AND target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}")
OR ($targetTableCursor IS NULL AND $newRecordCursor IS NULL AND target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}")
OR ($targetTableCursor = $newRecordCursor AND target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT")
OR ($targetTableCursor IS NULL AND $newRecordCursor IS NULL AND target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT")
OR ($targetTableCursor IS NULL AND $newRecordCursor IS $NOT_NULL)
)
""".trimIndent()
} else {
// No cursor - use extraction timestamp only
cursorComparison =
"""target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}""""
"""target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT""""
}
// Build column assignments for UPDATE
val columnAssignments: String =
columnUtils
.getFormattedColumnNames(
columns = stream.schema.asColumns(),
columnNameMapping = columnNameMapping,
quote = false,
)
.joinToString(",\n") { column ->
"${column.quote()} = new_record.${column.quote()}"
}
allColumns.joinToString(",\n ") { column ->
"${column.quote()} = new_record.${column.quote()}"
}
// Handle CDC deletions based on mode
val cdcDeleteClause: String
val cdcSkipInsertClause: String
if (
stream.schema.asColumns().containsKey(CDC_DELETED_AT_COLUMN) &&
snowflakeConfiguration.cdcDeletionMode == CdcDeletionMode.HARD_DELETE
finalSchema.containsKey(SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN) &&
config.cdcDeletionMode == CdcDeletionMode.HARD_DELETE
) {
// Execute CDC deletions if there's already a record
cdcDeleteClause =
"WHEN MATCHED AND new_record.\"${CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()}\" IS NOT NULL AND $cursorComparison THEN DELETE"
"WHEN MATCHED AND new_record.\"${SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN}\" IS NOT NULL AND $cursorComparison THEN DELETE"
// And skip insertion entirely if there's no matching record.
// (This is possible if a single T+D batch contains both an insertion and deletion for
// the same PK)
cdcSkipInsertClause =
"AND new_record.\"${CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()}\" IS NULL"
cdcSkipInsertClause = "AND new_record.\"${SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN}\" IS NULL"
} else {
cdcDeleteClause = ""
cdcSkipInsertClause = ""
@@ -203,35 +207,35 @@ class SnowflakeDirectLoadSqlGenerator(
val mergeStatement =
if (cdcDeleteClause.isNotEmpty()) {
"""
MERGE INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} AS target_table
USING (
$selectSourceRecords
) AS new_record
ON $pkEquivalent
$cdcDeleteClause
WHEN MATCHED AND $cursorComparison THEN UPDATE SET
$columnAssignments
WHEN NOT MATCHED $cdcSkipInsertClause THEN INSERT (
$columnList
) VALUES (
$newRecordColumnList
)
""".trimIndent()
|MERGE INTO ${fullyQualifiedName(targetTableName)} AS target_table
|USING (
|$selectSourceRecords
|) AS new_record
|ON $pkEquivalent
|$cdcDeleteClause
|WHEN MATCHED AND $cursorComparison THEN UPDATE SET
| $columnAssignments
|WHEN NOT MATCHED $cdcSkipInsertClause THEN INSERT (
| $columnList
|) VALUES (
| $newRecordColumnList
|)
""".trimMargin()
} else {
"""
MERGE INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} AS target_table
USING (
$selectSourceRecords
) AS new_record
ON $pkEquivalent
WHEN MATCHED AND $cursorComparison THEN UPDATE SET
$columnAssignments
WHEN NOT MATCHED THEN INSERT (
$columnList
) VALUES (
$newRecordColumnList
)
""".trimIndent()
|MERGE INTO ${fullyQualifiedName(targetTableName)} AS target_table
|USING (
|$selectSourceRecords
|) AS new_record
|ON $pkEquivalent
|WHEN MATCHED AND $cursorComparison THEN UPDATE SET
| $columnAssignments
|WHEN NOT MATCHED THEN INSERT (
| $columnList
|) VALUES (
| $newRecordColumnList
|)
""".trimMargin()
}
return mergeStatement.andLog()
@@ -242,75 +246,66 @@ class SnowflakeDirectLoadSqlGenerator(
* table. Uses ROW_NUMBER() window function to select the most recent record per primary key.
*/
private fun selectDedupedRecords(
stream: DestinationStream,
sourceTableName: TableName,
columnNameMapping: ColumnNameMapping
tableSchema: StreamTableSchema,
sourceTableName: TableName
): String {
val columnList: String =
columnUtils
.getFormattedColumnNames(
columns = stream.schema.asColumns(),
columnNameMapping = columnNameMapping,
quote = false,
)
.joinToString(
",\n",
) {
it.quote()
}
val importType = stream.importType as Dedupe
val allColumns = buildList {
add(SNOWFLAKE_AB_RAW_ID)
add(SNOWFLAKE_AB_EXTRACTED_AT)
add(SNOWFLAKE_AB_META)
add(SNOWFLAKE_AB_GENERATION_ID)
addAll(tableSchema.columnSchema.finalSchema.keys)
}
val columnList: String = allColumns.joinToString(",\n ") { it.quote() }
// Build the primary key list for partitioning
val pks = tableSchema.getPrimaryKey().flatten()
val pkList =
if (importType.primaryKey.isNotEmpty()) {
importType.primaryKey.joinToString(",") { fieldPath ->
(columnNameMapping[fieldPath.first()] ?: fieldPath.first()).quote()
}
if (pks.isNotEmpty()) {
pks.joinToString(",") { it.quote() }
} else {
// Should not happen as we check this earlier, but handle it defensively
throw IllegalArgumentException("Cannot deduplicate without primary key")
}
// Build cursor order clause for sorting within each partition
val cursor = tableSchema.getCursor().firstOrNull()
val cursorOrderClause =
if (importType.cursor.isNotEmpty()) {
val columnName =
(columnNameMapping[importType.cursor.first()] ?: importType.cursor.first())
.quote()
"$columnName DESC NULLS LAST,"
if (cursor != null) {
"${cursor.quote()} DESC NULLS LAST,"
} else {
""
}
return """
WITH records AS (
SELECT
$columnList
FROM ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)}
), numbered_rows AS (
SELECT *, ROW_NUMBER() OVER (
PARTITION BY $pkList ORDER BY $cursorOrderClause "${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" DESC
) AS row_number
FROM records
)
SELECT $columnList
FROM numbered_rows
WHERE row_number = 1
| WITH records AS (
| SELECT
| $columnList
| FROM ${fullyQualifiedName(sourceTableName)}
| ), numbered_rows AS (
| SELECT *, ROW_NUMBER() OVER (
| PARTITION BY $pkList ORDER BY $cursorOrderClause "$SNOWFLAKE_AB_EXTRACTED_AT" DESC
| ) AS row_number
| FROM records
| )
| SELECT $columnList
| FROM numbered_rows
| WHERE row_number = 1
"""
.trimIndent()
.trimMargin()
.andLog()
}
fun dropTable(tableName: TableName): String {
return "DROP TABLE IF EXISTS ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog()
return "DROP TABLE IF EXISTS ${fullyQualifiedName(tableName)}".andLog()
}
fun getGenerationId(
tableName: TableName,
): String {
return """
SELECT "${columnUtils.getGenerationIdColumnName()}"
FROM ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}
SELECT "${columnManager.getGenerationIdColumnName()}"
FROM ${fullyQualifiedName(tableName)}
LIMIT 1
"""
.trimIndent()
@@ -318,12 +313,12 @@ class SnowflakeDirectLoadSqlGenerator(
}
fun createSnowflakeStage(tableName: TableName): String {
val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName)
val stageName = fullyQualifiedStageName(tableName)
return "CREATE STAGE IF NOT EXISTS $stageName".andLog()
}
fun putInStage(tableName: TableName, tempFilePath: String): String {
val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true)
val stageName = fullyQualifiedStageName(tableName, true)
return """
PUT 'file://$tempFilePath' '@$stageName'
AUTO_COMPRESS = FALSE
@@ -334,35 +329,45 @@ class SnowflakeDirectLoadSqlGenerator(
.andLog()
}
fun copyFromStage(tableName: TableName, filename: String): String {
val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true)
fun copyFromStage(
tableName: TableName,
filename: String,
columnNames: List<String>? = null
): String {
val stageName = fullyQualifiedStageName(tableName, true)
val columnList =
columnNames?.let { names -> "(${names.joinToString(", ") { it.quote() }})" } ?: ""
return """
COPY INTO ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}
FROM '@$stageName'
FILE_FORMAT = (
TYPE = 'CSV'
COMPRESSION = GZIP
FIELD_DELIMITER = '$CSV_FIELD_SEPARATOR'
RECORD_DELIMITER = '$CSV_LINE_DELIMITER'
FIELD_OPTIONALLY_ENCLOSED_BY = '"'
TRIM_SPACE = TRUE
ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE
REPLACE_INVALID_CHARACTERS = TRUE
ESCAPE = NONE
ESCAPE_UNENCLOSED_FIELD = NONE
)
ON_ERROR = 'ABORT_STATEMENT'
PURGE = TRUE
files = ('$filename')
|COPY INTO ${fullyQualifiedName(tableName)}$columnList
|FROM '@$stageName'
|FILE_FORMAT = (
| TYPE = 'CSV'
| COMPRESSION = GZIP
| FIELD_DELIMITER = '$CSV_FIELD_SEPARATOR'
| RECORD_DELIMITER = '$CSV_LINE_DELIMITER'
| FIELD_OPTIONALLY_ENCLOSED_BY = '"'
| TRIM_SPACE = TRUE
| ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE
| REPLACE_INVALID_CHARACTERS = TRUE
| ESCAPE = NONE
| ESCAPE_UNENCLOSED_FIELD = NONE
|)
|ON_ERROR = 'ABORT_STATEMENT'
|PURGE = TRUE
|files = ('$filename')
"""
.trimIndent()
.trimMargin()
.andLog()
}
fun swapTableWith(sourceTableName: TableName, targetTableName: TableName): String {
return """
ALTER TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} SWAP WITH ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)}
ALTER TABLE ${fullyQualifiedName(sourceTableName)} SWAP WITH ${
fullyQualifiedName(
targetTableName,
)
}
"""
.trimIndent()
.andLog()
@@ -372,7 +377,11 @@ class SnowflakeDirectLoadSqlGenerator(
// Snowflake RENAME TO only accepts the table name, not a fully qualified name
// The renamed table stays in the same schema
return """
ALTER TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} RENAME TO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)}
ALTER TABLE ${fullyQualifiedName(sourceTableName)} RENAME TO ${
fullyQualifiedName(
targetTableName,
)
}
"""
.trimIndent()
.andLog()
@@ -382,7 +391,7 @@ class SnowflakeDirectLoadSqlGenerator(
schemaName: String,
tableName: String,
): String =
"""DESCRIBE TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(TableName(schemaName, tableName))}""".andLog()
"""DESCRIBE TABLE ${fullyQualifiedName(TableName(schemaName, tableName))}""".andLog()
fun alterTable(
tableName: TableName,
@@ -391,14 +400,14 @@ class SnowflakeDirectLoadSqlGenerator(
modifiedColumns: Map<String, ColumnTypeChange>,
): Set<String> {
val clauses = mutableSetOf<String>()
val prettyTableName = snowflakeSqlNameUtils.fullyQualifiedName(tableName)
val prettyTableName = fullyQualifiedName(tableName)
addedColumns.forEach { (name, columnType) ->
clauses.add(
// Note that we intentionally don't set NOT NULL.
// We're adding a new column, and we don't know what constitutes a reasonable
// default value for preexisting records.
// So we add the column as nullable.
"ALTER TABLE $prettyTableName ADD COLUMN ${name.quote()} ${columnType.type};".andLog()
"ALTER TABLE $prettyTableName ADD COLUMN ${name.quote()} ${columnType.type};".andLog(),
)
}
deletedColumns.forEach {
@@ -412,35 +421,34 @@ class SnowflakeDirectLoadSqlGenerator(
val tempColumn = "${name}_${uuidGenerator.v4()}"
clauses.add(
// As above: we add the column as nullable.
"ALTER TABLE $prettyTableName ADD COLUMN ${tempColumn.quote()} ${typeChange.newType.type};".andLog()
"ALTER TABLE $prettyTableName ADD COLUMN ${tempColumn.quote()} ${typeChange.newType.type};".andLog(),
)
clauses.add(
"UPDATE $prettyTableName SET ${tempColumn.quote()} = CAST(${name.quote()} AS ${typeChange.newType.type});".andLog()
"UPDATE $prettyTableName SET ${tempColumn.quote()} = CAST(${name.quote()} AS ${typeChange.newType.type});".andLog(),
)
val backupColumn = "${tempColumn}_backup"
clauses.add(
"""
ALTER TABLE $prettyTableName
RENAME COLUMN "$name" TO "$backupColumn";
""".trimIndent()
""".trimIndent(),
)
clauses.add(
"""
ALTER TABLE $prettyTableName
RENAME COLUMN "$tempColumn" TO "$name";
""".trimIndent()
""".trimIndent(),
)
clauses.add(
"ALTER TABLE $prettyTableName DROP COLUMN ${backupColumn.quote()};".andLog()
"ALTER TABLE $prettyTableName DROP COLUMN ${backupColumn.quote()};".andLog(),
)
} else if (!typeChange.originalType.nullable && typeChange.newType.nullable) {
// If the type is unchanged, we can change a column from NOT NULL to nullable.
// But we'll never do the reverse, because there's a decent chance that historical
// records
// had null values.
// records had null values.
// Users can always manually ALTER COLUMN ... SET NOT NULL if they want.
clauses.add(
"""ALTER TABLE $prettyTableName ALTER COLUMN "$name" DROP NOT NULL;""".andLog()
"""ALTER TABLE $prettyTableName ALTER COLUMN "$name" DROP NOT NULL;""".andLog(),
)
} else {
log.info {
@@ -450,4 +458,45 @@ class SnowflakeDirectLoadSqlGenerator(
}
return clauses
}
@VisibleForTesting
fun fullyQualifiedName(tableName: TableName): String =
combineParts(listOf(getDatabaseName(), tableName.namespace, tableName.name))
@VisibleForTesting
fun fullyQualifiedNamespace(namespace: String) =
combineParts(listOf(getDatabaseName(), namespace))
@VisibleForTesting
fun fullyQualifiedStageName(tableName: TableName, escape: Boolean = false): String {
// Escaping, when requested, is applied uniformly in combineParts below.
val currentTableName = tableName.name
return combineParts(
parts =
listOf(
getDatabaseName(),
tableName.namespace,
"$STAGE_NAME_PREFIX$currentTableName",
),
escape = escape,
)
}
@VisibleForTesting
internal fun combineParts(parts: List<String>, escape: Boolean = false): String =
parts
.map { if (escape) sqlEscape(it) else it }
.joinToString(separator = ".") {
if (!it.startsWith(QUOTE)) {
"$QUOTE$it$QUOTE"
} else {
it
}
}
private fun getDatabaseName() = config.database.toSnowflakeCompatibleName()
}
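
The name-assembly helpers at the bottom are pure string functions, so their behavior is easy to pin down. A runnable sketch of combineParts' quoting rule (escape handling omitted; when requested it runs sqlEscape over each part first):

fun combineParts(parts: List<String>): String =
    parts.joinToString(".") { if (!it.startsWith("\"")) "\"$it\"" else it }

fun main() {
    // A fully qualified table name: database.schema.table, each part quoted.
    println(combineParts(listOf("AIRBYTE_DB", "PUBLIC", "USERS")))
    // "AIRBYTE_DB"."PUBLIC"."USERS"
}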

View File

@@ -0,0 +1,29 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.sql
const val STAGE_NAME_PREFIX = "airbyte_stage_"
internal const val QUOTE: String = "\""
fun sqlEscape(part: String) = part.replace("\\", "\\\\").replace("'", "\\'").replace("\"", "\\\"")
/**
* Surrounds the string instance with double quotation marks (e.g. "some string" -> "\"some
* string\"").
*/
fun String.quote() = "$QUOTE$this$QUOTE"
/**
 * Escapes double-quotes in a JSON identifier by doubling them, which is how Snowflake
 * escapes quotes inside quoted identifiers. Retained from the legacy connector for
 * backwards compatibility.
 *
 * @return The escaped identifier.
 */
fun escapeJsonIdentifier(identifier: String): String {
// Note that we don't need to escape backslashes here!
// The only special character in an identifier is the double-quote, which needs to be
// doubled.
return identifier.replace(QUOTE, "$QUOTE$QUOTE")
}
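// Illustrative behavior of the helpers above (example inputs assumed):
//   quote:                col    -> "col"
//   sqlEscape:            a"b\c  -> a\"b\\c   (backslashes doubled, quotes backslash-escaped)
//   escapeJsonIdentifier: a"b    -> a""b      (double-quotes doubled)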

View File

@@ -1,57 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.sql
import io.airbyte.cdk.load.table.TableName
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import jakarta.inject.Singleton
const val STAGE_NAME_PREFIX = "airbyte_stage_"
internal const val QUOTE: String = "\""
fun sqlEscape(part: String) = part.replace("\\", "\\\\").replace("'", "\\'").replace("\"", "\\\"")
@Singleton
class SnowflakeSqlNameUtils(
private val snowflakeConfiguration: SnowflakeConfiguration,
) {
fun fullyQualifiedName(tableName: TableName): String =
combineParts(listOf(getDatabaseName(), tableName.namespace, tableName.name))
fun fullyQualifiedNamespace(namespace: String) =
combineParts(listOf(getDatabaseName(), namespace))
fun fullyQualifiedStageName(tableName: TableName, escape: Boolean = false): String {
val currentTableName =
if (escape) {
tableName.name
} else {
tableName.name
}
return combineParts(
parts =
listOf(
getDatabaseName(),
tableName.namespace,
"$STAGE_NAME_PREFIX$currentTableName"
),
escape = escape,
)
}
fun combineParts(parts: List<String>, escape: Boolean = false): String =
parts
.map { if (escape) sqlEscape(it) else it }
.joinToString(separator = ".") {
if (!it.startsWith(QUOTE)) {
"$QUOTE$it$QUOTE"
} else {
it
}
}
private fun getDatabaseName() = snowflakeConfiguration.database.toSnowflakeCompatibleName()
}

View File

@@ -6,38 +6,39 @@ package io.airbyte.integrations.destination.snowflake.write
import io.airbyte.cdk.SystemErrorException
import io.airbyte.cdk.load.command.Dedupe
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.orchestration.db.DatabaseInitialStatusGatherer
import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadInitialStatus
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendStreamLoader
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendTruncateStreamLoader
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableDedupStreamLoader
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableDedupTruncateStreamLoader
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.DatabaseInitialStatusGatherer
import io.airbyte.cdk.load.table.TempTableNameGenerator
import io.airbyte.cdk.load.table.directload.DirectLoadInitialStatus
import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendStreamLoader
import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendTruncateStreamLoader
import io.airbyte.cdk.load.table.directload.DirectLoadTableDedupStreamLoader
import io.airbyte.cdk.load.table.directload.DirectLoadTableDedupTruncateStreamLoader
import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.write.DestinationWriter
import io.airbyte.cdk.load.write.StreamLoader
import io.airbyte.cdk.load.write.StreamStateStore
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
import jakarta.inject.Singleton
@Singleton
class SnowflakeWriter(
private val names: TableCatalog,
private val catalog: DestinationCatalog,
private val stateGatherer: DatabaseInitialStatusGatherer<DirectLoadInitialStatus>,
private val streamStateStore: StreamStateStore<DirectLoadTableExecutionConfig>,
private val snowflakeClient: SnowflakeAirbyteClient,
private val tempTableNameGenerator: TempTableNameGenerator,
private val snowflakeConfiguration: SnowflakeConfiguration,
private val tempTableNameGenerator: TempTableNameGenerator,
) : DestinationWriter {
private lateinit var initialStatuses: Map<DestinationStream, DirectLoadInitialStatus>
override suspend fun setup() {
names.values
.map { (tableNames, _) -> tableNames.finalTableName!!.namespace }
catalog.streams
.map { it.tableSchema.tableNames.finalTableName!!.namespace }
.toSet()
.forEach { snowflakeClient.createNamespace(it) }
@@ -45,15 +46,15 @@ class SnowflakeWriter(
escapeJsonIdentifier(snowflakeConfiguration.internalTableSchema)
)
initialStatuses = stateGatherer.gatherInitialStatus(names)
initialStatuses = stateGatherer.gatherInitialStatus()
}
override fun createStreamLoader(stream: DestinationStream): StreamLoader {
val initialStatus = initialStatuses[stream]!!
val tableNameInfo = names[stream]!!
val realTableName = tableNameInfo.tableNames.finalTableName!!
val tempTableName = tempTableNameGenerator.generate(realTableName)
val columnNameMapping = tableNameInfo.columnNameMapping
val realTableName = stream.tableSchema.tableNames.finalTableName!!
val tempTableName = stream.tableSchema.tableNames.tempTableName!!
val columnNameMapping =
ColumnNameMapping(stream.tableSchema.columnSchema.inputToFinalColumnNames)
return when (stream.minimumGenerationId) {
0L ->
when (stream.importType) {

View File

@@ -9,11 +9,12 @@ import de.siegmar.fastcsv.writer.CsvWriter
import de.siegmar.fastcsv.writer.LineDelimiter
import de.siegmar.fastcsv.writer.QuoteStrategies
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.QUOTE
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.github.oshai.kotlinlogging.KotlinLogging
import java.io.File
import java.io.OutputStream
@@ -36,10 +37,11 @@ private const val CSV_WRITER_BUFFER_SIZE = 1024 * 1024 // 1 MB
class SnowflakeInsertBuffer(
private val tableName: TableName,
val columns: LinkedHashMap<String, String>,
private val snowflakeClient: SnowflakeAirbyteClient,
val snowflakeConfiguration: SnowflakeConfiguration,
val snowflakeColumnUtils: SnowflakeColumnUtils,
val columnSchema: ColumnSchema,
private val columnManager: SnowflakeColumnManager,
private val snowflakeRecordFormatter: SnowflakeRecordFormatter,
private val flushLimit: Int = DEFAULT_FLUSH_LIMIT,
) {
@@ -57,12 +59,6 @@ class SnowflakeInsertBuffer(
.lineDelimiter(CSV_LINE_DELIMITER)
.quoteStrategy(QuoteStrategies.REQUIRED)
private val snowflakeRecordFormatter: SnowflakeRecordFormatter =
when (snowflakeConfiguration.legacyRawTablesOnly) {
true -> SnowflakeRawRecordFormatter(columns, snowflakeColumnUtils)
else -> SnowflakeSchemaRecordFormatter(columns, snowflakeColumnUtils)
}
fun accumulate(recordFields: Map<String, AirbyteValue>) {
if (csvFilePath == null) {
val csvFile = createCsvFile()
@@ -92,7 +88,9 @@ class SnowflakeInsertBuffer(
"Copying staging data into ${tableName.toPrettyString(quote = QUOTE)}..."
}
// Finally, copy the data from the staging table to the final table
snowflakeClient.copyFromStage(tableName, filePath.fileName.toString())
// Pass column names to ensure correct mapping even after ALTER TABLE operations
val columnNames = columnManager.getTableColumnNames(columnSchema)
snowflakeClient.copyFromStage(tableName, filePath.fileName.toString(), columnNames)
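// Sketch of the resulting statement, with assumed names (the SQL generator builds the
// real query): an explicit column list keeps the CSV fields aligned with the table even
// if ALTER TABLE has reordered or added columns, e.g.
//   COPY INTO "DB"."SCHEMA"."TBL" ("ID", "NAME", ...)
//   FROM @"DB"."SCHEMA"."airbyte_stage_TBL" FILES = ('file.csv.gz')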
logger.info {
"Finished insert of $recordCount row(s) into ${tableName.toPrettyString(quote = QUOTE)}."
}
@@ -117,7 +115,9 @@ class SnowflakeInsertBuffer(
private fun writeToCsvFile(record: Map<String, AirbyteValue>) {
csvWriter?.let {
it.writeRecord(snowflakeRecordFormatter.format(record).map { col -> col.toString() })
it.writeRecord(
snowflakeRecordFormatter.format(record, columnSchema).map { col -> col.toString() }
)
recordCount++
if ((recordCount % flushLimit) == 0) {
it.flush()

View File

@@ -8,99 +8,62 @@ import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.data.csv.toCsvValue
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.cdk.load.util.Jsons
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
interface SnowflakeRecordFormatter {
fun format(record: Map<String, AirbyteValue>): List<Any>
fun format(record: Map<String, AirbyteValue>, columnSchema: ColumnSchema): List<Any>
}
class SnowflakeSchemaRecordFormatter(
private val columns: LinkedHashMap<String, String>,
val snowflakeColumnUtils: SnowflakeColumnUtils,
) : SnowflakeRecordFormatter {
class SnowflakeSchemaRecordFormatter : SnowflakeRecordFormatter {
override fun format(record: Map<String, AirbyteValue>, columnSchema: ColumnSchema): List<Any> {
val result = mutableListOf<Any>()
val userColumns = columnSchema.finalSchema.keys
private val airbyteColumnNames =
snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet()
// WARNING: MUST match the order defined in SnowflakeColumnManager#getTableColumnNames
//
// Why don't we just use that here? Well, unlike the user fields, the meta fields on the
// record are not munged for the destination. So we must access the values for those columns
// using the original lowercase meta key.
result.add(record[COLUMN_NAME_AB_RAW_ID].toCsvValue())
result.add(record[COLUMN_NAME_AB_EXTRACTED_AT].toCsvValue())
result.add(record[COLUMN_NAME_AB_META].toCsvValue())
result.add(record[COLUMN_NAME_AB_GENERATION_ID].toCsvValue())
override fun format(record: Map<String, AirbyteValue>): List<Any> =
columns.map { (columnName, _) ->
/*
* Meta columns are forced to uppercase for backwards compatibility with previous
* versions of the destination. Therefore, convert the column to lowercase so
* that it can match the constants, which use the lowercase version of the meta
* column names.
*/
if (airbyteColumnNames.contains(columnName)) {
record[columnName.lowercase()].toCsvValue()
} else {
record.keys
// The columns retrieved from Snowflake do not have any escaping applied.
// Therefore, re-apply the compatible name escaping to the name of the
// columns retrieved from Snowflake. The record keys should already have
// been escaped by the CDK before arriving at the aggregate, so no need
// to escape again here.
.find { it == columnName.toSnowflakeCompatibleName() }
?.let { record[it].toCsvValue() }
?: ""
}
}
// Add user columns from the final schema
userColumns.forEach { columnName -> result.add(record[columnName].toCsvValue()) }
return result
}
}
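// Illustrative row layout emitted by the schema formatter above, assuming a final schema
// with user columns ID and NAME: the four meta values always lead, in the fixed order
//   [_airbyte_raw_id, _airbyte_extracted_at, _airbyte_meta, _airbyte_generation_id, ID, NAME]
// so the CSV column order matches SnowflakeColumnManager#getTableColumnNames.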
class SnowflakeRawRecordFormatter(
columns: LinkedHashMap<String, String>,
val snowflakeColumnUtils: SnowflakeColumnUtils,
) : SnowflakeRecordFormatter {
private val columns = columns.keys
class SnowflakeRawRecordFormatter : SnowflakeRecordFormatter {
private val airbyteColumnNames =
snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet()
override fun format(record: Map<String, AirbyteValue>): List<Any> =
override fun format(record: Map<String, AirbyteValue>, columnSchema: ColumnSchema): List<Any> =
toOutputRecord(record.toMutableMap())
private fun toOutputRecord(record: MutableMap<String, AirbyteValue>): List<Any> {
val outputRecord = mutableListOf<Any>()
// Copy the Airbyte metadata columns to the raw output, removing each
// one from the record to avoid duplicates in the "data" field
columns
.filter { airbyteColumnNames.contains(it) && it != Meta.COLUMN_NAME_DATA }
.forEach { column -> safeAddToOutput(column, record, outputRecord) }
val mutableRecord = record.toMutableMap()
// Add meta columns in order (except _airbyte_data which we handle specially)
outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_RAW_ID)?.toCsvValue() ?: "")
outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_EXTRACTED_AT)?.toCsvValue() ?: "")
outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_META)?.toCsvValue() ?: "")
outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_GENERATION_ID)?.toCsvValue() ?: "")
outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_LOADED_AT)?.toCsvValue() ?: "")
// Do not output null values in the JSON raw output
val filteredRecord = record.filter { (_, v) -> v !is NullValue }
// Convert all the remaining columns in the record to a JSON document stored in the "data"
// column. Add it in the same position as the _airbyte_data column in the column list to
// ensure it is inserted into the proper column in the table.
insert(
columns.indexOf(Meta.COLUMN_NAME_DATA),
StringValue(Jsons.writeValueAsString(filteredRecord)).toCsvValue(),
outputRecord
)
val filteredRecord = mutableRecord.filter { (_, v) -> v !is NullValue }
// Convert all the remaining columns to a JSON document stored in the "data" column
outputRecord.add(StringValue(Jsons.writeValueAsString(filteredRecord)).toCsvValue())
return outputRecord
}
private fun safeAddToOutput(
key: String,
record: MutableMap<String, AirbyteValue>,
output: MutableList<Any>
) {
val extractedValue = record.remove(key)
// Ensure that the data is inserted into the list at the same position as the column
insert(columns.indexOf(key), extractedValue?.toCsvValue() ?: "", output)
}
private fun insert(index: Int, value: Any, list: MutableList<Any>) {
/*
* Attempt to insert the value into the proper order in the list. If the index
* is already present in the list, use the add(index, element) method to insert it
* into the proper order and push everything to the right. If the index is at the
* end of the list, just use add(element) to insert it at the end. If the index
* is further beyond the end of the list, throw an exception as that should not occur.
*/
if (index < list.size) list.add(index, value)
else if (index == list.size || index == list.size + 1) list.add(value)
else throw IndexOutOfBoundsException()
}
}
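// Illustrative raw-mode output for an assumed record {id=1, name=NullValue}: the five
// meta values lead (missing ones become ""), and the remaining non-null fields are
// serialized into the JSON "data" column:
//   [raw_id, extracted_at, meta, generation_id, loaded_at, {"id":1}]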

View File

@@ -1,25 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.write.transform
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.dataflow.transform.ColumnNameMapper
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import jakarta.inject.Singleton
@Singleton
class SnowflakeColumnNameMapper(
private val catalogInfo: TableCatalog,
private val snowflakeConfiguration: SnowflakeConfiguration,
) : ColumnNameMapper {
override fun getMappedColumnName(stream: DestinationStream, columnName: String): String {
if (snowflakeConfiguration.legacyRawTablesOnly == true) {
return columnName
} else {
return catalogInfo.getMappedColumnName(stream, columnName)!!
}
}
}

View File

@@ -7,9 +7,11 @@ package io.airbyte.integrations.destination.snowflake.component
import io.airbyte.cdk.load.component.TableOperationsFixtures
import io.airbyte.cdk.load.component.TableOperationsSuite
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.idTestWithCdcMapping
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.testMapping
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.idTestWithCdcMapping
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.testMapping
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeTestTableOperationsClient
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.parallel.Execution
@@ -20,6 +22,7 @@ import org.junit.jupiter.api.parallel.ExecutionMode
class SnowflakeTableOperationsTest(
override val client: SnowflakeAirbyteClient,
override val testClient: SnowflakeTestTableOperationsClient,
override val schemaFactory: TableSchemaFactory,
) : TableOperationsSuite {
override val airbyteMetaColumnMapping = Meta.COLUMN_NAMES.associateWith { it.uppercase() }

View File

@@ -9,13 +9,15 @@ import io.airbyte.cdk.load.component.TableSchemaEvolutionFixtures
import io.airbyte.cdk.load.component.TableSchemaEvolutionSuite
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.util.serializeToString
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.allTypesColumnNameMapping
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.allTypesTableSchema
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.idAndTestMapping
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.testMapping
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.allTypesColumnNameMapping
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.allTypesTableSchema
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.idAndTestMapping
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.testMapping
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeTestTableOperationsClient
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.parallel.Execution
@@ -27,6 +29,7 @@ class SnowflakeTableSchemaEvolutionTest(
override val client: SnowflakeAirbyteClient,
override val opsClient: SnowflakeAirbyteClient,
override val testClient: SnowflakeTestTableOperationsClient,
override val schemaFactory: TableSchemaFactory,
) : TableSchemaEvolutionSuite {
override val airbyteMetaColumnMapping = Meta.COLUMN_NAMES.associateWith { it.uppercase() }

View File

@@ -2,7 +2,7 @@
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.component
package io.airbyte.integrations.destination.snowflake.component.config
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.component.TableOperationsFixtures
@@ -33,6 +33,8 @@ object SnowflakeComponentTestFixtures {
"TIME_NTZ" to ColumnType("TIME", true),
"ARRAY" to ColumnType("ARRAY", true),
"OBJECT" to ColumnType("OBJECT", true),
"UNION" to ColumnType("VARIANT", true),
"LEGACY_UNION" to ColumnType("VARIANT", true),
"UNKNOWN" to ColumnType("VARIANT", true),
)
)

View File

@@ -2,7 +2,7 @@
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.component
package io.airbyte.integrations.destination.snowflake.component.config
import io.airbyte.cdk.load.component.config.TestConfigLoader.loadTestConfig
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration

View File

@@ -2,22 +2,24 @@
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.component
package io.airbyte.integrations.destination.snowflake.component.config
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.component.TestTableOperationsClient
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.dataflow.state.PartitionKey
import io.airbyte.cdk.load.dataflow.transform.RecordDTO
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.util.Jsons
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.client.execute
import io.airbyte.integrations.destination.snowflake.dataflow.SnowflakeAggregate
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
import io.airbyte.integrations.destination.snowflake.sql.andLog
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
import io.micronaut.context.annotation.Requires
import jakarta.inject.Singleton
import java.time.format.DateTimeFormatter
@@ -29,25 +31,40 @@ import net.snowflake.client.jdbc.SnowflakeTimestampWithTimezone
class SnowflakeTestTableOperationsClient(
private val client: SnowflakeAirbyteClient,
private val dataSource: DataSource,
private val snowflakeSqlNameUtils: SnowflakeSqlNameUtils,
private val snowflakeColumnUtils: SnowflakeColumnUtils,
private val sqlGenerator: SnowflakeDirectLoadSqlGenerator,
private val snowflakeConfiguration: SnowflakeConfiguration,
private val columnManager: SnowflakeColumnManager,
private val snowflakeRecordFormatter: SnowflakeRecordFormatter,
) : TestTableOperationsClient {
override suspend fun dropNamespace(namespace: String) {
dataSource.execute(
"DROP SCHEMA IF EXISTS ${snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)}".andLog()
"DROP SCHEMA IF EXISTS ${sqlGenerator.fullyQualifiedNamespace(namespace)}".andLog()
)
}
override suspend fun insertRecords(table: TableName, records: List<Map<String, AirbyteValue>>) {
// TODO: pass a proper column schema instead of reconstructing one here.
// Fetch the columns and filter out the meta columns so that only user columns remain.
val columnTypes =
client.describeTable(table).filterNot {
columnManager.getMetaColumnNames().contains(it.key)
}
val columnSchema =
ColumnSchema(
inputToFinalColumnNames = columnTypes.keys.associateWith { it },
finalSchema = columnTypes.mapValues { ColumnType("", true) },
inputSchema = emptyMap(), // Not needed for the insert buffer
)
val aggregate =
SnowflakeAggregate(
SnowflakeInsertBuffer(
table,
client.describeTable(table),
client,
snowflakeConfiguration,
snowflakeColumnUtils,
tableName = table,
snowflakeClient = client,
snowflakeConfiguration = snowflakeConfiguration,
columnSchema = columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
)
)
records.forEach { aggregate.accept(RecordDTO(it, PartitionKey(""), 0, 0)) }

View File

@@ -7,11 +7,11 @@ package io.airbyte.integrations.destination.snowflake.write
import io.airbyte.cdk.load.test.util.DestinationCleaner
import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory
import io.airbyte.integrations.destination.snowflake.cdk.SnowflakeMigratingConfigurationSpecificationSupplier
import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfigurationFactory
import io.airbyte.integrations.destination.snowflake.sql.STAGE_NAME_PREFIX
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
import io.airbyte.integrations.destination.snowflake.sql.quote
import java.nio.file.Files
import java.sql.Connection

View File

@@ -12,16 +12,22 @@ import io.airbyte.cdk.load.data.ObjectValue
import io.airbyte.cdk.load.data.json.toAirbyteValue
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN
import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator
import io.airbyte.cdk.load.test.util.DestinationDataDumper
import io.airbyte.cdk.load.test.util.OutputRecord
import io.airbyte.cdk.load.util.UUIDGenerator
import io.airbyte.cdk.load.util.deserializeToNode
import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory
import io.airbyte.integrations.destination.snowflake.db.SnowflakeFinalTableNameGenerator
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeTableSchemaMapper
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
import io.airbyte.integrations.destination.snowflake.sql.sqlEscape
import java.math.BigDecimal
import java.sql.Date
import java.sql.Time
import java.sql.Timestamp
import net.snowflake.client.jdbc.SnowflakeTimestampWithTimezone
private val AIRBYTE_META_COLUMNS = Meta.COLUMN_NAMES + setOf(CDC_DELETED_AT_COLUMN)
@@ -34,8 +40,14 @@ class SnowflakeDataDumper(
stream: DestinationStream
): List<OutputRecord> {
val config = configProvider(spec)
val sqlUtils = SnowflakeSqlNameUtils(config)
val snowflakeFinalTableNameGenerator = SnowflakeFinalTableNameGenerator(config)
val snowflakeFinalTableNameGenerator =
SnowflakeTableSchemaMapper(
config = config,
tempTableNameGenerator = DefaultTempTableNameGenerator(),
)
val snowflakeColumnManager = SnowflakeColumnManager(config)
val sqlGenerator =
SnowflakeDirectLoadSqlGenerator(UUIDGenerator(), config, snowflakeColumnManager)
val dataSource =
SnowflakeBeanFactory()
.snowflakeDataSource(snowflakeConfiguration = config, airbyteEdition = "COMMUNITY")
@@ -46,7 +58,7 @@ class SnowflakeDataDumper(
ds.connection.use { connection ->
val statement = connection.createStatement()
val tableName =
snowflakeFinalTableNameGenerator.getTableName(stream.mappedDescriptor)
snowflakeFinalTableNameGenerator.toFinalTableName(stream.mappedDescriptor)
// First check if the table exists
val tableExistsQuery =
@@ -69,7 +81,7 @@ class SnowflakeDataDumper(
val resultSet =
statement.executeQuery(
"SELECT * FROM ${sqlUtils.fullyQualifiedName(tableName)}"
"SELECT * FROM ${sqlGenerator.fullyQualifiedName(tableName)}"
)
while (resultSet.next()) {
@@ -143,10 +155,10 @@ class SnowflakeDataDumper(
private fun convertValue(value: Any?): Any? =
when (value) {
is BigDecimal -> value.toBigInteger()
is java.sql.Date -> value.toLocalDate()
is Date -> value.toLocalDate()
is SnowflakeTimestampWithTimezone -> value.toZonedDateTime()
is java.sql.Time -> value.toLocalTime()
is java.sql.Timestamp -> value.toLocalDateTime()
is Time -> value.toLocalTime()
is Timestamp -> value.toLocalDateTime()
else -> value
}
}

View File

@@ -15,7 +15,7 @@ import io.airbyte.cdk.load.dataflow.transform.ValidationResult
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.test.util.ExpectedRecordMapper
import io.airbyte.cdk.load.test.util.OutputRecord
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.write.transform.SnowflakeValueCoercer
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Change

View File

@@ -5,7 +5,7 @@
package io.airbyte.integrations.destination.snowflake.write
import io.airbyte.cdk.load.test.util.NameMapper
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
class SnowflakeNameMapper : NameMapper {
override fun mapFieldName(path: List<String>): List<String> =

View File

@@ -9,13 +9,16 @@ import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.data.json.toAirbyteValue
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator
import io.airbyte.cdk.load.test.util.DestinationDataDumper
import io.airbyte.cdk.load.test.util.OutputRecord
import io.airbyte.cdk.load.util.UUIDGenerator
import io.airbyte.cdk.load.util.deserializeToNode
import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory
import io.airbyte.integrations.destination.snowflake.db.SnowflakeFinalTableNameGenerator
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeTableSchemaMapper
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
class SnowflakeRawDataDumper(
private val configProvider: (ConfigurationSpecification) -> SnowflakeConfiguration
@@ -27,8 +30,18 @@ class SnowflakeRawDataDumper(
val output = mutableListOf<OutputRecord>()
val config = configProvider(spec)
val sqlUtils = SnowflakeSqlNameUtils(config)
val snowflakeFinalTableNameGenerator = SnowflakeFinalTableNameGenerator(config)
val snowflakeColumnManager = SnowflakeColumnManager(config)
val sqlGenerator =
SnowflakeDirectLoadSqlGenerator(
UUIDGenerator(),
config,
snowflakeColumnManager,
)
val snowflakeFinalTableNameGenerator =
SnowflakeTableSchemaMapper(
config = config,
tempTableNameGenerator = DefaultTempTableNameGenerator(),
)
val dataSource =
SnowflakeBeanFactory()
.snowflakeDataSource(snowflakeConfiguration = config, airbyteEdition = "COMMUNITY")
@@ -37,11 +50,11 @@ class SnowflakeRawDataDumper(
ds.connection.use { connection ->
val statement = connection.createStatement()
val tableName =
snowflakeFinalTableNameGenerator.getTableName(stream.mappedDescriptor)
snowflakeFinalTableNameGenerator.toFinalTableName(stream.mappedDescriptor)
val resultSet =
statement.executeQuery(
"SELECT * FROM ${sqlUtils.fullyQualifiedName(tableName)}"
"SELECT * FROM ${sqlGenerator.fullyQualifiedName(tableName)}"
)
while (resultSet.next()) {

View File

@@ -6,7 +6,7 @@ package io.airbyte.integrations.destination.snowflake
import com.zaxxer.hikari.HikariConfig
import com.zaxxer.hikari.HikariDataSource
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.CdcDeletionMode
import io.airbyte.integrations.destination.snowflake.spec.KeyPairAuthConfiguration
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration

View File

@@ -5,11 +5,9 @@
package io.airbyte.integrations.destination.snowflake.check
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.DEFAULT_COLUMNS
import io.airbyte.integrations.destination.snowflake.sql.RAW_DATA_COLUMN
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
@@ -23,47 +21,31 @@ internal class SnowflakeCheckerTest {
@ParameterizedTest
@ValueSource(booleans = [true, false])
fun testSuccessfulCheck(isLegacyRawTablesOnly: Boolean) {
val defaultColumnsMap =
if (isLegacyRawTablesOnly) {
linkedMapOf<String, String>().also { map ->
(DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
map[it.columnName] = it.columnType
}
}
} else {
linkedMapOf<String, String>().also { map ->
(DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
map[it.columnName.toSnowflakeCompatibleName()] = it.columnType
}
}
}
val defaultColumns = defaultColumnsMap.keys.toMutableList()
val snowflakeAirbyteClient: SnowflakeAirbyteClient =
mockk(relaxed = true) {
coEvery { countTable(any()) } returns 1L
coEvery { describeTable(any()) } returns defaultColumnsMap
}
mockk(relaxed = true) { coEvery { countTable(any()) } returns 1L }
val testSchema = "test-schema"
val snowflakeConfiguration: SnowflakeConfiguration = mockk {
every { schema } returns testSchema
every { legacyRawTablesOnly } returns isLegacyRawTablesOnly
}
val snowflakeColumnUtils =
mockk<SnowflakeColumnUtils>(relaxUnitFun = true) {
every { getFormattedDefaultColumnNames(any()) } returns defaultColumns
}
val columnManager: SnowflakeColumnManager = SnowflakeColumnManager(snowflakeConfiguration)
val checker =
SnowflakeChecker(
snowflakeAirbyteClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
snowflakeColumnUtils = snowflakeColumnUtils,
columnManager = columnManager,
)
checker.check()
coVerify(exactly = 1) {
snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
if (isLegacyRawTablesOnly) {
snowflakeAirbyteClient.createNamespace(testSchema)
} else {
snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
}
}
coVerify(exactly = 1) { snowflakeAirbyteClient.createTable(any(), any(), any(), any()) }
coVerify(exactly = 1) { snowflakeAirbyteClient.dropTable(any()) }
@@ -72,48 +54,32 @@ internal class SnowflakeCheckerTest {
@ParameterizedTest
@ValueSource(booleans = [true, false])
fun testUnsuccessfulCheck(isLegacyRawTablesOnly: Boolean) {
val defaultColumnsMap =
if (isLegacyRawTablesOnly) {
linkedMapOf<String, String>().also { map ->
(DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
map[it.columnName] = it.columnType
}
}
} else {
linkedMapOf<String, String>().also { map ->
(DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
map[it.columnName.toSnowflakeCompatibleName()] = it.columnType
}
}
}
val defaultColumns = defaultColumnsMap.keys.toMutableList()
val snowflakeAirbyteClient: SnowflakeAirbyteClient =
mockk(relaxed = true) {
coEvery { countTable(any()) } returns 0L
coEvery { describeTable(any()) } returns defaultColumnsMap
}
mockk(relaxed = true) { coEvery { countTable(any()) } returns 0L }
val testSchema = "test-schema"
val snowflakeConfiguration: SnowflakeConfiguration = mockk {
every { schema } returns testSchema
every { legacyRawTablesOnly } returns isLegacyRawTablesOnly
}
val snowflakeColumnUtils =
mockk<SnowflakeColumnUtils>(relaxUnitFun = true) {
every { getFormattedDefaultColumnNames(any()) } returns defaultColumns
}
val columnManager: SnowflakeColumnManager = SnowflakeColumnManager(snowflakeConfiguration)
val checker =
SnowflakeChecker(
snowflakeAirbyteClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
snowflakeColumnUtils = snowflakeColumnUtils,
columnManager = columnManager,
)
assertThrows<IllegalArgumentException> { checker.check() }
coVerify(exactly = 1) {
snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
if (isLegacyRawTablesOnly) {
snowflakeAirbyteClient.createNamespace(testSchema)
} else {
snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
}
}
coVerify(exactly = 1) { snowflakeAirbyteClient.createTable(any(), any(), any(), any()) }
coVerify(exactly = 1) { snowflakeAirbyteClient.dropTable(any()) }

View File

@@ -6,24 +6,16 @@ package io.airbyte.integrations.destination.snowflake.client
import io.airbyte.cdk.ConfigErrorException
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.command.NamespaceMapper
import io.airbyte.cdk.load.command.Overwrite
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.config.NamespaceDefinitionType
import io.airbyte.cdk.load.data.AirbyteType
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.TableName
import io.airbyte.integrations.destination.snowflake.db.ColumnDefinition
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.COUNT_TOTAL_ALIAS
import io.airbyte.integrations.destination.snowflake.sql.ColumnAndType
import io.airbyte.integrations.destination.snowflake.sql.DEFAULT_COLUMNS
import io.airbyte.integrations.destination.snowflake.sql.QUOTE
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
import io.mockk.Runs
import io.mockk.every
@@ -49,31 +41,18 @@ internal class SnowflakeAirbyteClientTest {
private lateinit var client: SnowflakeAirbyteClient
private lateinit var dataSource: DataSource
private lateinit var sqlGenerator: SnowflakeDirectLoadSqlGenerator
private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
private lateinit var snowflakeConfiguration: SnowflakeConfiguration
private lateinit var columnManager: SnowflakeColumnManager
@BeforeEach
fun setup() {
dataSource = mockk()
sqlGenerator = mockk(relaxed = true)
snowflakeColumnUtils =
mockk(relaxed = true) {
every { formatColumnName(any()) } answers
{
firstArg<String>().toSnowflakeCompatibleName()
}
every { getFormattedDefaultColumnNames(any()) } returns
DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() }
}
snowflakeConfiguration =
mockk(relaxed = true) { every { database } returns "test_database" }
columnManager = mockk(relaxed = true)
client =
SnowflakeAirbyteClient(
dataSource,
sqlGenerator,
snowflakeColumnUtils,
snowflakeConfiguration
)
SnowflakeAirbyteClient(dataSource, sqlGenerator, snowflakeConfiguration, columnManager)
}
@Test
@@ -231,7 +210,7 @@ internal class SnowflakeAirbyteClientTest {
@Test
fun testCreateTable() {
val columnNameMapping = mockk<ColumnNameMapping>(relaxed = true)
val stream = mockk<DestinationStream>()
val stream = mockk<DestinationStream>(relaxed = true)
val tableName = TableName(namespace = "namespace", name = "name")
val resultSet = mockk<ResultSet>(relaxed = true)
val statement =
@@ -254,9 +233,7 @@ internal class SnowflakeAirbyteClientTest {
columnNameMapping = columnNameMapping,
replace = true,
)
verify(exactly = 1) {
sqlGenerator.createTable(stream, tableName, columnNameMapping, true)
}
verify(exactly = 1) { sqlGenerator.createTable(tableName, any(), true) }
verify(exactly = 1) { sqlGenerator.createSnowflakeStage(tableName) }
verify(exactly = 2) { mockConnection.close() }
}
@@ -288,7 +265,7 @@ internal class SnowflakeAirbyteClientTest {
targetTableName = destinationTableName,
)
verify(exactly = 1) {
sqlGenerator.copyTable(columnNameMapping, sourceTableName, destinationTableName)
sqlGenerator.copyTable(any<Set<String>>(), sourceTableName, destinationTableName)
}
verify(exactly = 1) { mockConnection.close() }
}
@@ -299,7 +276,7 @@ internal class SnowflakeAirbyteClientTest {
val columnNameMapping = mockk<ColumnNameMapping>(relaxed = true)
val sourceTableName = TableName(namespace = "namespace", name = "source")
val destinationTableName = TableName(namespace = "namespace", name = "destination")
val stream = mockk<DestinationStream>()
val stream = mockk<DestinationStream>(relaxed = true)
val resultSet = mockk<ResultSet>(relaxed = true)
val statement =
mockk<Statement> {
@@ -322,12 +299,7 @@ internal class SnowflakeAirbyteClientTest {
targetTableName = destinationTableName,
)
verify(exactly = 1) {
sqlGenerator.upsertTable(
stream,
columnNameMapping,
sourceTableName,
destinationTableName
)
sqlGenerator.upsertTable(any(), sourceTableName, destinationTableName)
}
verify(exactly = 1) { mockConnection.close() }
}
@@ -379,7 +351,7 @@ internal class SnowflakeAirbyteClientTest {
}
every { dataSource.connection } returns mockConnection
every { snowflakeColumnUtils.getGenerationIdColumnName() } returns generationIdColumnName
every { columnManager.getGenerationIdColumnName() } returns generationIdColumnName
every { sqlGenerator.getGenerationId(tableName) } returns
"SELECT $generationIdColumnName FROM ${tableName.toPrettyString(QUOTE)}"
@@ -501,8 +473,8 @@ internal class SnowflakeAirbyteClientTest {
every { dataSource.connection } returns mockConnection
runBlocking {
client.copyFromStage(tableName, "test.csv.gz")
verify(exactly = 1) { sqlGenerator.copyFromStage(tableName, "test.csv.gz") }
client.copyFromStage(tableName, "test.csv.gz", listOf())
verify(exactly = 1) { sqlGenerator.copyFromStage(tableName, "test.csv.gz", listOf()) }
verify(exactly = 1) { mockConnection.close() }
}
}
@@ -556,7 +528,7 @@ internal class SnowflakeAirbyteClientTest {
"COL1" andThen
COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName() andThen
"COL2"
every { resultSet.getString("type") } returns "VARCHAR(255)" andThen "NUMBER(38,0)"
every { resultSet.getString("type") } returns "VARCHAR(255)" andThen "NUMBER"
every { resultSet.getString("null?") } returns "Y" andThen "N" andThen "N"
val statement =
@@ -571,6 +543,10 @@ internal class SnowflakeAirbyteClientTest {
every { dataSource.connection } returns connection
// Mock the columnManager to return the correct set of meta columns
every { columnManager.getMetaColumnNames() } returns
setOf(COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName())
val result = client.getColumnsFromDb(tableName)
val expectedColumns =
@@ -582,81 +558,6 @@ internal class SnowflakeAirbyteClientTest {
assertEquals(expectedColumns, result)
}
@Test
fun `getColumnsFromStream should return correct column definitions`() {
val schema = mockk<AirbyteType>()
val stream =
DestinationStream(
unmappedNamespace = "test_namespace",
unmappedName = "test_stream",
importType = Overwrite,
schema = schema,
generationId = 1,
minimumGenerationId = 1,
syncId = 1,
namespaceMapper = NamespaceMapper(NamespaceDefinitionType.DESTINATION)
)
val columnNameMapping =
ColumnNameMapping(
mapOf(
"col1" to "COL1_MAPPED",
"col2" to "COL2_MAPPED",
)
)
val col1FieldType = mockk<FieldType>()
every { col1FieldType.type } returns mockk()
val col2FieldType = mockk<FieldType>()
every { col2FieldType.type } returns mockk()
every { schema.asColumns() } returns
linkedMapOf("col1" to col1FieldType, "col2" to col2FieldType)
every { snowflakeColumnUtils.toDialectType(col1FieldType.type) } returns "VARCHAR(255)"
every { snowflakeColumnUtils.toDialectType(col2FieldType.type) } returns "NUMBER(38,0)"
every { snowflakeColumnUtils.columnsAndTypes(any(), any()) } returns
listOf(ColumnAndType("COL1_MAPPED", "VARCHAR"), ColumnAndType("COL2_MAPPED", "NUMBER"))
every { snowflakeColumnUtils.formatColumnName(any(), false) } answers
{
firstArg<String>().toSnowflakeCompatibleName()
}
val result = client.getColumnsFromStream(stream, columnNameMapping)
val expectedColumns =
mapOf(
"COL1_MAPPED" to ColumnType("VARCHAR", true),
"COL2_MAPPED" to ColumnType("NUMBER", true),
)
assertEquals(expectedColumns, result)
}
@Test
fun `generateSchemaChanges should correctly identify changes`() {
val columnsInDb =
setOf(
ColumnDefinition("COL1", "VARCHAR"),
ColumnDefinition("COL2", "NUMBER"),
ColumnDefinition("COL3", "BOOLEAN")
)
val columnsInStream =
setOf(
ColumnDefinition("COL1", "VARCHAR"), // Unchanged
ColumnDefinition("COL3", "TEXT"), // Modified
ColumnDefinition("COL4", "DATE") // Added
)
val (added, deleted, modified) = client.generateSchemaChanges(columnsInDb, columnsInStream)
assertEquals(1, added.size)
assertEquals("COL4", added.first().name)
assertEquals(1, deleted.size)
assertEquals("COL2", deleted.first().name)
assertEquals(1, modified.size)
assertEquals("COL3", modified.first().name)
}
@Test
fun testCreateNamespaceWithNetworkFailure() {
val namespace = "test_namespace"

View File

@@ -4,14 +4,19 @@
package io.airbyte.integrations.destination.snowflake.dataflow
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.command.DestinationStream.Descriptor
import io.airbyte.cdk.load.dataflow.aggregate.StoreKey
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.write.StreamStateStore
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRawRecordFormatter
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertEquals
@@ -22,27 +27,33 @@ internal class SnowflakeAggregateFactoryTest {
@Test
fun testCreatingAggregateWithRawBuffer() {
val descriptor = Descriptor(namespace = "namespace", name = "name")
val directLoadTableExecutionConfig =
DirectLoadTableExecutionConfig(
tableName =
TableName(
namespace = descriptor.namespace!!,
name = descriptor.name,
)
val tableName =
TableName(
namespace = descriptor.namespace!!,
name = descriptor.name,
)
val directLoadTableExecutionConfig = DirectLoadTableExecutionConfig(tableName = tableName)
val key = StoreKey(namespace = descriptor.namespace!!, name = descriptor.name)
val streamStore = StreamStateStore<DirectLoadTableExecutionConfig>()
streamStore.put(descriptor, directLoadTableExecutionConfig)
streamStore.put(key, directLoadTableExecutionConfig)
val stream = mockk<DestinationStream>(relaxed = true)
val catalog = mockk<DestinationCatalog> { every { getStream(key) } returns stream }
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val snowflakeConfiguration =
mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns true }
val snowflakeColumnUtils = mockk<SnowflakeColumnUtils>(relaxed = true)
val columnManager = SnowflakeColumnManager(snowflakeConfiguration)
val snowflakeRecordFormatter: SnowflakeRecordFormatter = SnowflakeRawRecordFormatter()
val factory =
SnowflakeAggregateFactory(
snowflakeClient = snowflakeClient,
streamStateStore = streamStore,
snowflakeConfiguration = snowflakeConfiguration,
snowflakeColumnUtils = snowflakeColumnUtils,
catalog = catalog,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
)
val aggregate = factory.create(key)
assertNotNull(aggregate)
@@ -52,26 +63,33 @@ internal class SnowflakeAggregateFactoryTest {
@Test
fun testCreatingAggregateWithStagingBuffer() {
val descriptor = Descriptor(namespace = "namespace", name = "name")
val directLoadTableExecutionConfig =
DirectLoadTableExecutionConfig(
tableName =
TableName(
namespace = descriptor.namespace!!,
name = descriptor.name,
)
val tableName =
TableName(
namespace = descriptor.namespace!!,
name = descriptor.name,
)
val directLoadTableExecutionConfig = DirectLoadTableExecutionConfig(tableName = tableName)
val key = StoreKey(namespace = descriptor.namespace!!, name = descriptor.name)
val streamStore = StreamStateStore<DirectLoadTableExecutionConfig>()
streamStore.put(descriptor, directLoadTableExecutionConfig)
streamStore.put(key, directLoadTableExecutionConfig)
val stream = mockk<DestinationStream>(relaxed = true)
val catalog = mockk<DestinationCatalog> { every { getStream(key) } returns stream }
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val snowflakeConfiguration = mockk<SnowflakeConfiguration>(relaxed = true)
val snowflakeColumnUtils = mockk<SnowflakeColumnUtils>(relaxed = true)
val snowflakeConfiguration =
mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns false }
val columnManager = SnowflakeColumnManager(snowflakeConfiguration)
val snowflakeRecordFormatter: SnowflakeRecordFormatter = SnowflakeSchemaRecordFormatter()
val factory =
SnowflakeAggregateFactory(
snowflakeClient = snowflakeClient,
streamStateStore = streamStore,
snowflakeConfiguration = snowflakeConfiguration,
snowflakeColumnUtils = snowflakeColumnUtils,
catalog = catalog,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
)
val aggregate = factory.create(key)
assertNotNull(aggregate)

View File

@@ -1,23 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.db
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
internal class SnowflakeColumnNameGeneratorTest {
@Test
fun testGetColumnName() {
val column = "test-column"
val generator =
SnowflakeColumnNameGenerator(mockk { every { legacyRawTablesOnly } returns false })
val columnName = generator.getColumnName(column)
assertEquals(column.toSnowflakeCompatibleName(), columnName.displayName)
assertEquals(column.toSnowflakeCompatibleName(), columnName.canonicalName)
}
}

View File

@@ -1,72 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.db
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
internal class SnowflakeFinalTableNameGeneratorTest {
@Test
fun testGetTableNameWithInternalNamespace() {
val configuration =
mockk<SnowflakeConfiguration> {
every { internalTableSchema } returns "test-internal-namespace"
every { legacyRawTablesOnly } returns true
}
val generator = SnowflakeFinalTableNameGenerator(config = configuration)
val streamName = "test-stream-name"
val streamNamespace = "test-stream-namespace"
val streamDescriptor =
mockk<DestinationStream.Descriptor> {
every { namespace } returns streamNamespace
every { name } returns streamName
}
val tableName = generator.getTableName(streamDescriptor)
assertEquals("test-stream-namespace_raw__stream_test-stream-name", tableName.name)
assertEquals("test-internal-namespace", tableName.namespace)
}
@Test
fun testGetTableNameWithNamespace() {
val configuration =
mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns false }
val generator = SnowflakeFinalTableNameGenerator(config = configuration)
val streamName = "test-stream-name"
val streamNamespace = "test-stream-namespace"
val streamDescriptor =
mockk<DestinationStream.Descriptor> {
every { namespace } returns streamNamespace
every { name } returns streamName
}
val tableName = generator.getTableName(streamDescriptor)
assertEquals("TEST-STREAM-NAME", tableName.name)
assertEquals("TEST-STREAM-NAMESPACE", tableName.namespace)
}
@Test
fun testGetTableNameWithDefaultNamespace() {
val defaultNamespace = "test-default-namespace"
val configuration =
mockk<SnowflakeConfiguration> {
every { schema } returns defaultNamespace
every { legacyRawTablesOnly } returns false
}
val generator = SnowflakeFinalTableNameGenerator(config = configuration)
val streamName = "test-stream-name"
val streamDescriptor =
mockk<DestinationStream.Descriptor> {
every { namespace } returns null
every { name } returns streamName
}
val tableName = generator.getTableName(streamDescriptor)
assertEquals("TEST-STREAM-NAME", tableName.name)
assertEquals("TEST-DEFAULT-NAMESPACE", tableName.namespace)
}
}

View File

@@ -1,27 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.db
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.CsvSource
internal class SnowflakeNameGeneratorsTest {
@ParameterizedTest
@CsvSource(
value =
[
"test-name,TEST-NAME",
"1-test-name,1-TEST-NAME",
"test-name!!!,TEST-NAME!!!",
"test\${name,TEST__NAME",
"test\"name,TEST\"\"NAME",
]
)
fun testToSnowflakeCompatibleName(name: String, expected: String) {
assertEquals(expected, name.toSnowflakeCompatibleName())
}
}

View File

@@ -1,358 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.sql
import com.fasterxml.jackson.databind.JsonNode
import io.airbyte.cdk.load.data.ArrayType
import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
import io.airbyte.cdk.load.data.BooleanType
import io.airbyte.cdk.load.data.DateType
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerType
import io.airbyte.cdk.load.data.NumberType
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.TimeTypeWithTimezone
import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
import io.airbyte.cdk.load.data.UnionType
import io.airbyte.cdk.load.data.UnknownType
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA
import io.airbyte.cdk.load.orchestration.db.ColumnNameGenerator
import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.integrations.destination.snowflake.db.SnowflakeColumnNameGenerator
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.mockk.every
import io.mockk.mockk
import kotlin.collections.LinkedHashMap
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.CsvSource
internal class SnowflakeColumnUtilsTest {
private lateinit var snowflakeConfiguration: SnowflakeConfiguration
private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
private lateinit var snowflakeColumnNameGenerator: SnowflakeColumnNameGenerator
@BeforeEach
fun setup() {
snowflakeConfiguration = mockk(relaxed = true)
snowflakeColumnNameGenerator =
mockk(relaxed = true) {
every { getColumnName(any()) } answers
{
val displayName =
if (snowflakeConfiguration.legacyRawTablesOnly) firstArg<String>()
else firstArg<String>().toSnowflakeCompatibleName()
val canonicalName =
if (snowflakeConfiguration.legacyRawTablesOnly) firstArg<String>()
else firstArg<String>().toSnowflakeCompatibleName()
ColumnNameGenerator.ColumnName(
displayName = displayName,
canonicalName = canonicalName,
)
}
}
snowflakeColumnUtils =
SnowflakeColumnUtils(snowflakeConfiguration, snowflakeColumnNameGenerator)
}
@Test
fun testDefaultColumns() {
val expectedDefaultColumns = DEFAULT_COLUMNS
assertEquals(expectedDefaultColumns, snowflakeColumnUtils.defaultColumns())
}
@Test
fun testDefaultRawColumns() {
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val expectedDefaultColumns = DEFAULT_COLUMNS + RAW_COLUMNS
assertEquals(expectedDefaultColumns, snowflakeColumnUtils.defaultColumns())
}
@Test
fun testGetFormattedDefaultColumnNames() {
val expectedDefaultColumnNames =
DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() }
val defaultColumnNames = snowflakeColumnUtils.getFormattedDefaultColumnNames()
assertEquals(expectedDefaultColumnNames, defaultColumnNames)
}
@Test
fun testGetFormattedDefaultColumnNamesQuoted() {
val expectedDefaultColumnNames =
DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName().quote() }
val defaultColumnNames = snowflakeColumnUtils.getFormattedDefaultColumnNames(true)
assertEquals(expectedDefaultColumnNames, defaultColumnNames)
}
@Test
fun testGetColumnName() {
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val columnNames = snowflakeColumnUtils.getColumnNames(columnNameMapping)
val expectedColumnNames =
(DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() } + listOf("actual"))
.joinToString(",") { it.quote() }
assertEquals(expectedColumnNames, columnNames)
}
@Test
fun testGetRawColumnName() {
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val columnNames = snowflakeColumnUtils.getColumnNames(columnNameMapping)
val expectedColumnNames =
(DEFAULT_COLUMNS.map { it.columnName } + RAW_COLUMNS.map { it.columnName })
.joinToString(",") { it.quote() }
assertEquals(expectedColumnNames, columnNames)
}
@Test
fun testGetRawFormattedColumnNames() {
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val schemaColumns =
mapOf(
"column_one" to FieldType(StringType, true),
"column_two" to FieldType(IntegerType, true),
"original" to FieldType(StringType, true),
CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true)
)
val expectedColumnNames =
DEFAULT_COLUMNS.map { it.columnName.quote() } +
RAW_COLUMNS.map { it.columnName.quote() }
val columnNames =
snowflakeColumnUtils.getFormattedColumnNames(
columns = schemaColumns,
columnNameMapping = columnNameMapping
)
assertEquals(expectedColumnNames.size, columnNames.size)
assertEquals(expectedColumnNames.sorted(), columnNames.sorted())
}
@Test
fun testGetFormattedColumnNames() {
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val schemaColumns =
mapOf(
"column_one" to FieldType(StringType, true),
"column_two" to FieldType(IntegerType, true),
"original" to FieldType(StringType, true),
CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true)
)
val expectedColumnNames =
listOf(
"actual",
"column_one",
"column_two",
CDC_DELETED_AT_COLUMN,
)
.map { it.quote() } +
DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName().quote() }
val columnNames =
snowflakeColumnUtils.getFormattedColumnNames(
columns = schemaColumns,
columnNameMapping = columnNameMapping
)
assertEquals(expectedColumnNames.size, columnNames.size)
assertEquals(expectedColumnNames.sorted(), columnNames.sorted())
}
@Test
fun testGetFormattedColumnNamesNoQuotes() {
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val schemaColumns =
mapOf(
"column_one" to FieldType(StringType, true),
"column_two" to FieldType(IntegerType, true),
"original" to FieldType(StringType, true),
CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true)
)
val expectedColumnNames =
listOf(
"actual",
"column_one",
"column_two",
CDC_DELETED_AT_COLUMN,
) + DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() }
val columnNames =
snowflakeColumnUtils.getFormattedColumnNames(
columns = schemaColumns,
columnNameMapping = columnNameMapping,
quote = false
)
assertEquals(expectedColumnNames.size, columnNames.size)
assertEquals(expectedColumnNames.sorted(), columnNames.sorted())
}
@Test
fun testGeneratingRawTableColumnsAndTypesNoColumnMapping() {
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val columns =
snowflakeColumnUtils.columnsAndTypes(
columns = emptyMap(),
columnNameMapping = ColumnNameMapping(emptyMap())
)
assertEquals(DEFAULT_COLUMNS.size + RAW_COLUMNS.size, columns.size)
assertEquals(
"${SnowflakeDataType.VARIANT.typeName} $NOT_NULL",
columns.find { it.columnName == RAW_DATA_COLUMN.columnName }?.columnType
)
}
@Test
fun testGeneratingColumnsAndTypesNoColumnMapping() {
val columnName = "test-column"
val fieldType = FieldType(StringType, false)
val declaredColumns = mapOf(columnName to fieldType)
val columns =
snowflakeColumnUtils.columnsAndTypes(
columns = declaredColumns,
columnNameMapping = ColumnNameMapping(emptyMap())
)
assertEquals(DEFAULT_COLUMNS.size + 1, columns.size)
assertEquals(
"${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL",
columns.find { it.columnName == columnName }?.columnType
)
}
@Test
fun testGeneratingColumnsAndTypesWithColumnMapping() {
val columnName = "test-column"
val mappedColumnName = "mapped-column-name"
val fieldType = FieldType(StringType, false)
val declaredColumns = mapOf(columnName to fieldType)
val columnNameMapping = ColumnNameMapping(mapOf(columnName to mappedColumnName))
val columns =
snowflakeColumnUtils.columnsAndTypes(
columns = declaredColumns,
columnNameMapping = columnNameMapping
)
assertEquals(DEFAULT_COLUMNS.size + 1, columns.size)
assertEquals(
"${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL",
columns.find { it.columnName == mappedColumnName }?.columnType
)
}
@Test
fun testToDialectType() {
assertEquals(
SnowflakeDataType.BOOLEAN.typeName,
snowflakeColumnUtils.toDialectType(BooleanType)
)
assertEquals(SnowflakeDataType.DATE.typeName, snowflakeColumnUtils.toDialectType(DateType))
assertEquals(
SnowflakeDataType.NUMBER.typeName,
snowflakeColumnUtils.toDialectType(IntegerType)
)
assertEquals(
SnowflakeDataType.FLOAT.typeName,
snowflakeColumnUtils.toDialectType(NumberType)
)
assertEquals(
SnowflakeDataType.VARCHAR.typeName,
snowflakeColumnUtils.toDialectType(StringType)
)
assertEquals(
SnowflakeDataType.VARCHAR.typeName,
snowflakeColumnUtils.toDialectType(TimeTypeWithTimezone)
)
assertEquals(
SnowflakeDataType.TIME.typeName,
snowflakeColumnUtils.toDialectType(TimeTypeWithoutTimezone)
)
assertEquals(
SnowflakeDataType.TIMESTAMP_TZ.typeName,
snowflakeColumnUtils.toDialectType(TimestampTypeWithTimezone)
)
assertEquals(
SnowflakeDataType.TIMESTAMP_NTZ.typeName,
snowflakeColumnUtils.toDialectType(TimestampTypeWithoutTimezone)
)
assertEquals(
SnowflakeDataType.ARRAY.typeName,
snowflakeColumnUtils.toDialectType(ArrayType(items = FieldType(StringType, false)))
)
assertEquals(
SnowflakeDataType.ARRAY.typeName,
snowflakeColumnUtils.toDialectType(ArrayTypeWithoutSchema)
)
assertEquals(
SnowflakeDataType.OBJECT.typeName,
snowflakeColumnUtils.toDialectType(
ObjectType(
properties = LinkedHashMap(),
additionalProperties = false,
)
)
)
assertEquals(
SnowflakeDataType.OBJECT.typeName,
snowflakeColumnUtils.toDialectType(ObjectTypeWithEmptySchema)
)
assertEquals(
SnowflakeDataType.OBJECT.typeName,
snowflakeColumnUtils.toDialectType(ObjectTypeWithoutSchema)
)
assertEquals(
SnowflakeDataType.VARIANT.typeName,
snowflakeColumnUtils.toDialectType(
UnionType(
options = setOf(StringType),
isLegacyUnion = true,
)
)
)
assertEquals(
SnowflakeDataType.VARIANT.typeName,
snowflakeColumnUtils.toDialectType(
UnionType(
options = emptySet(),
isLegacyUnion = false,
)
)
)
assertEquals(
SnowflakeDataType.VARIANT.typeName,
snowflakeColumnUtils.toDialectType(UnknownType(schema = mockk<JsonNode>()))
)
}
@ParameterizedTest
@CsvSource(
value =
[
"$COLUMN_NAME_DATA, true, \"$COLUMN_NAME_DATA\"",
"some-other_Column, true, \"SOME-OTHER_COLUMN\"",
"$COLUMN_NAME_DATA, false, $COLUMN_NAME_DATA",
"some-other_Column, false, SOME-OTHER_COLUMN",
"$COLUMN_NAME_DATA, true, \"$COLUMN_NAME_DATA\"",
"some-other_Column, true, \"SOME-OTHER_COLUMN\"",
]
)
fun testFormatColumnName(columnName: String, quote: Boolean, expectedFormattedName: String) {
assertEquals(
expectedFormattedName,
snowflakeColumnUtils.formatColumnName(columnName, quote)
)
}
}
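
Read together, the assertions in testToDialectType pin down the complete Airbyte-to-Snowflake type mapping. A condensed restatement as a when-expression, illustrative only: the SnowflakeDataType names come straight from the assertions above, AirbyteType is the CDK's type hierarchy, and the else branch is an assumption for exhaustiveness.

fun toDialectTypeSketch(type: AirbyteType): String =
    when (type) {
        BooleanType -> "BOOLEAN"
        DateType -> "DATE"
        IntegerType -> "NUMBER"
        NumberType -> "FLOAT"
        StringType, TimeTypeWithTimezone -> "VARCHAR" // time-with-tz is stored as text
        TimeTypeWithoutTimezone -> "TIME"
        TimestampTypeWithTimezone -> "TIMESTAMP_TZ"
        TimestampTypeWithoutTimezone -> "TIMESTAMP_NTZ"
        is ArrayType, ArrayTypeWithoutSchema -> "ARRAY"
        is ObjectType, ObjectTypeWithEmptySchema, ObjectTypeWithoutSchema -> "OBJECT"
        else -> "VARIANT" // unions and unknown types fall back to VARIANT
    }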

View File

@@ -1,97 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.sql
import io.airbyte.cdk.load.table.TableName
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
internal class SnowflakeSqlNameUtilsTest {
private lateinit var snowflakeConfiguration: SnowflakeConfiguration
private lateinit var snowflakeSqlNameUtils: SnowflakeSqlNameUtils
@BeforeEach
fun setUp() {
snowflakeConfiguration = mockk(relaxed = true)
snowflakeSqlNameUtils =
SnowflakeSqlNameUtils(snowflakeConfiguration = snowflakeConfiguration)
}
@Test
fun testFullyQualifiedName() {
val databaseName = "test-database"
val namespace = "test-namespace"
val name = "test=name"
val tableName = TableName(namespace = namespace, name = name)
every { snowflakeConfiguration.database } returns databaseName
val expectedName =
snowflakeSqlNameUtils.combineParts(
listOf(
databaseName.toSnowflakeCompatibleName(),
tableName.namespace,
tableName.name
)
)
val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedName(tableName)
assertEquals(expectedName, fullyQualifiedName)
}
@Test
fun testFullyQualifiedNamespace() {
val databaseName = "test-database"
val namespace = "test-namespace"
every { snowflakeConfiguration.database } returns databaseName
val fullyQualifiedNamespace = snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)
assertEquals("\"TEST-DATABASE\".\"test-namespace\"", fullyQualifiedNamespace)
}
@Test
fun testFullyQualifiedStageName() {
val databaseName = "test-database"
val namespace = "test-namespace"
val name = "test=name"
val tableName = TableName(namespace = namespace, name = name)
every { snowflakeConfiguration.database } returns databaseName
val expectedName =
snowflakeSqlNameUtils.combineParts(
listOf(
databaseName.toSnowflakeCompatibleName(),
namespace,
"$STAGE_NAME_PREFIX$name"
)
)
val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName)
assertEquals(expectedName, fullyQualifiedName)
}
@Test
fun testFullyQualifiedStageNameWithEscape() {
val databaseName = "test-database"
val namespace = "test-namespace"
val name = "test=\"\"\'name"
val tableName = TableName(namespace = namespace, name = name)
every { snowflakeConfiguration.database } returns databaseName
val expectedName =
snowflakeSqlNameUtils.combineParts(
listOf(
databaseName.toSnowflakeCompatibleName(),
namespace,
"$STAGE_NAME_PREFIX${sqlEscape(name)}"
)
)
val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true)
assertEquals(expectedName, fullyQualifiedName)
}
}
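
The expected values above imply a simple composition rule: normalize the database name with toSnowflakeCompatibleName, leave the namespace and table name as given, wrap every part in double quotes, and join with dots (stage names additionally get the STAGE_NAME_PREFIX and, when requested, sqlEscape). A sketch of the join step under those assumptions; combinePartsSketch is a stand-in for the class's combineParts, not its real signature:

fun combinePartsSketch(parts: List<String>): String =
    parts.joinToString(".") { "\"$it\"" }

// combinePartsSketch(listOf("TEST-DATABASE", "test-namespace"))
//   returns "TEST-DATABASE"."test-namespace", matching testFullyQualifiedNamespace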

View File

@@ -6,20 +6,21 @@ package io.airbyte.integrations.destination.snowflake.write
import io.airbyte.cdk.SystemErrorException
import io.airbyte.cdk.load.command.Append
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.orchestration.db.DatabaseInitialStatusGatherer
import io.airbyte.cdk.load.orchestration.db.TableNames
import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadInitialStatus
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendStreamLoader
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendTruncateStreamLoader
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableStatus
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableNameInfo
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.cdk.load.schema.model.StreamTableSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.schema.model.TableNames
import io.airbyte.cdk.load.table.DatabaseInitialStatusGatherer
import io.airbyte.cdk.load.table.TempTableNameGenerator
import io.airbyte.cdk.load.table.directload.DirectLoadInitialStatus
import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendStreamLoader
import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendTruncateStreamLoader
import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.table.directload.DirectLoadTableStatus
import io.airbyte.cdk.load.write.StreamStateStore
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
@@ -34,55 +35,93 @@ internal class SnowflakeWriterTest {
@Test
fun testSetup() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val stream = mockk<DestinationStream>()
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream =
mockk<DestinationStream> {
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
every { importType } returns Append
}
val catalog = DestinationCatalog(listOf(stream))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns emptyMap()
coEvery { gatherInitialStatus() } returns
mapOf(
stream to
DirectLoadInitialStatus(
realTable = DirectLoadTableStatus(false),
tempTable = null
)
)
}
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer =
SnowflakeWriter(
names = catalog,
catalog = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(relaxed = true),
snowflakeConfiguration =
mockk(relaxed = true) {
every { internalTableSchema } returns "internal_schema"
},
)
runBlocking { writer.setup() }
coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) }
coVerify(exactly = 1) { stateGatherer.gatherInitialStatus(catalog) }
coVerify(exactly = 1) { stateGatherer.gatherInitialStatus() }
}
@Test
fun testCreateStreamLoaderFirstGeneration() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream =
mockk<DestinationStream> {
every { minimumGenerationId } returns 0L
every { generationId } returns 0L
every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
}
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val catalog = DestinationCatalog(listOf(stream))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns
coEvery { gatherInitialStatus() } returns
mapOf(
stream to
DirectLoadInitialStatus(
@@ -93,14 +132,18 @@ internal class SnowflakeWriterTest {
}
val tempTableNameGenerator =
mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } }
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer =
SnowflakeWriter(
names = catalog,
catalog = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient,
tempTableNameGenerator = tempTableNameGenerator,
snowflakeConfiguration = mockk(relaxed = true),
snowflakeConfiguration =
mockk(relaxed = true) {
every { internalTableSchema } returns "internal_schema"
},
)
runBlocking {
@@ -113,23 +156,35 @@ internal class SnowflakeWriterTest {
@Test
fun testCreateStreamLoaderNotFirstGeneration() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream =
mockk<DestinationStream> {
every { minimumGenerationId } returns 1L
every { generationId } returns 1L
every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
}
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val catalog = DestinationCatalog(listOf(stream))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns
coEvery { gatherInitialStatus() } returns
mapOf(
stream to
DirectLoadInitialStatus(
@@ -140,14 +195,18 @@ internal class SnowflakeWriterTest {
}
val tempTableNameGenerator =
mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } }
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer =
SnowflakeWriter(
names = catalog,
catalog = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient,
tempTableNameGenerator = tempTableNameGenerator,
snowflakeConfiguration = mockk(relaxed = true),
snowflakeConfiguration =
mockk(relaxed = true) {
every { internalTableSchema } returns "internal_schema"
},
)
runBlocking {
@@ -160,22 +219,35 @@ internal class SnowflakeWriterTest {
@Test
fun testCreateStreamLoaderHybrid() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream =
mockk<DestinationStream> {
every { minimumGenerationId } returns 1L
every { generationId } returns 2L
every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
}
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val catalog = DestinationCatalog(listOf(stream))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns
coEvery { gatherInitialStatus() } returns
mapOf(
stream to
DirectLoadInitialStatus(
@@ -186,14 +258,18 @@ internal class SnowflakeWriterTest {
}
val tempTableNameGenerator =
mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } }
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer =
SnowflakeWriter(
names = catalog,
catalog = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient,
tempTableNameGenerator = tempTableNameGenerator,
snowflakeConfiguration = mockk(relaxed = true),
snowflakeConfiguration =
mockk(relaxed = true) {
every { internalTableSchema } returns "internal_schema"
},
)
runBlocking {
@@ -203,169 +279,126 @@ internal class SnowflakeWriterTest {
}
@Test
fun testSetupWithNamespaceCreationFailure() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val stream = mockk<DestinationStream>()
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val snowflakeClient = mockk<SnowflakeAirbyteClient>()
val stateGatherer = mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>>()
val writer =
SnowflakeWriter(
names = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(),
)
// Simulate network failure during namespace creation
coEvery {
snowflakeClient.createNamespace(tableName.namespace.toSnowflakeCompatibleName())
} throws RuntimeException("Network connection failed")
assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } }
}
@Test
fun testSetupWithInitialStatusGatheringFailure() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val stream = mockk<DestinationStream>()
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer = mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>>()
val writer =
SnowflakeWriter(
names = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(),
)
// Simulate failure while gathering initial status
coEvery { stateGatherer.gatherInitialStatus(catalog) } throws
RuntimeException("Failed to query table status")
assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } }
// Verify namespace creation was still attempted
coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) }
}
@Test
fun testCreateStreamLoaderWithMissingInitialStatus() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
fun testCreateStreamLoaderNamespaceLegacy() {
val namespace = "test-namespace"
val name = "test-name"
val tableName = TableName(namespace = namespace, name = name)
val tempTableName = TableName(namespace = namespace, name = "${name}-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream =
mockk<DestinationStream> {
every { minimumGenerationId } returns 0L
every { generationId } returns 0L
every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
}
val missingStream =
mockk<DestinationStream> {
every { minimumGenerationId } returns 0L
every { generationId } returns 0L
}
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val catalog = DestinationCatalog(listOf(stream))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns
coEvery { gatherInitialStatus() } returns
mapOf(
stream to
DirectLoadInitialStatus(
realTable = DirectLoadTableStatus(false),
tempTable = null,
tempTable = null
)
)
}
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer =
SnowflakeWriter(
names = catalog,
catalog = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(relaxed = true),
snowflakeConfiguration =
mockk(relaxed = true) {
every { legacyRawTablesOnly } returns true
every { internalTableSchema } returns "internal_schema"
},
)
runBlocking {
writer.setup()
// Try to create loader for a stream that wasn't in initial status
assertThrows(NullPointerException::class.java) {
writer.createStreamLoader(missingStream)
runBlocking { writer.setup() }
coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) }
}
@Test
fun testCreateStreamLoaderNamespaceNonLegacy() {
val namespace = "test-namespace"
val name = "test-name"
val tableName = TableName(namespace = namespace, name = name)
val tempTableName = TableName(namespace = namespace, name = "${name}-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream =
mockk<DestinationStream> {
every { minimumGenerationId } returns 0L
every { generationId } returns 0L
every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
}
}
}
@Test
fun testCreateStreamLoaderWithNullFinalTableName() {
// TableNames constructor throws IllegalStateException when both names are null
assertThrows(IllegalStateException::class.java) {
TableNames(rawTableName = null, finalTableName = null)
}
}
@Test
fun testSetupWithMultipleNamespaceFailuresPartial() {
val tableName1 = TableName(namespace = "namespace1", name = "table1")
val tableName2 = TableName(namespace = "namespace2", name = "table2")
val tableNames1 = TableNames(rawTableName = null, finalTableName = tableName1)
val tableNames2 = TableNames(rawTableName = null, finalTableName = tableName2)
val stream1 = mockk<DestinationStream>()
val stream2 = mockk<DestinationStream>()
val tableInfo1 =
TableNameInfo(
tableNames = tableNames1,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val tableInfo2 =
TableNameInfo(
tableNames = tableNames2,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream1 to tableInfo1, stream2 to tableInfo2))
val snowflakeClient = mockk<SnowflakeAirbyteClient>()
val stateGatherer = mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>>()
val catalog = DestinationCatalog(listOf(stream))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus() } returns
mapOf(
stream to
DirectLoadInitialStatus(
realTable = DirectLoadTableStatus(false),
tempTable = null
)
)
}
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer =
SnowflakeWriter(
names = catalog,
catalog = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(),
snowflakeConfiguration =
mockk(relaxed = true) {
every { legacyRawTablesOnly } returns false
every { internalTableSchema } returns "internal_schema"
},
)
// First namespace succeeds, second fails (namespaces are uppercased by
// toSnowflakeCompatibleName)
coEvery { snowflakeClient.createNamespace("namespace1") } returns Unit
coEvery { snowflakeClient.createNamespace("namespace2") } throws
RuntimeException("Connection timeout")
runBlocking { writer.setup() }
assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } }
// Verify both namespace creations were attempted
coVerify(exactly = 1) { snowflakeClient.createNamespace("namespace1") }
coVerify(exactly = 1) { snowflakeClient.createNamespace("namespace2") }
coVerify(exactly = 1) { snowflakeClient.createNamespace(namespace) }
}
}
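
The three generation scenarios (minimumGenerationId/generationId of 0/0, 1/1, and 1/2) are what drive the choice between the imported append and append-truncate loaders. A sketch of that dispatch, inferred from the test names rather than the writer's actual code, so treat the labels as assumptions:

fun loaderCaseSketch(minimumGenerationId: Long, generationId: Long): String =
    when {
        minimumGenerationId == 0L -> "first generation: plain append"     // testCreateStreamLoaderFirstGeneration
        generationId == minimumGenerationId -> "truncate-style refresh"   // testCreateStreamLoaderNotFirstGeneration
        else -> "hybrid: refresh spanning multiple generations"           // testCreateStreamLoaderHybrid
    }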

View File

@@ -4,218 +4,220 @@
package io.airbyte.integrations.destination.snowflake.write.load
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerValue
import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.mockk.coVerify
import io.mockk.every
import io.mockk.mockk
import java.io.BufferedReader
import java.io.File
import java.io.InputStreamReader
import java.util.zip.GZIPInputStream
import kotlin.io.path.exists
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNotNull
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
internal class SnowflakeInsertBufferTest {
private lateinit var snowflakeConfiguration: SnowflakeConfiguration
private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
private lateinit var columnManager: SnowflakeColumnManager
private lateinit var columnSchema: ColumnSchema
private lateinit var snowflakeRecordFormatter: SnowflakeRecordFormatter
@BeforeEach
fun setUp() {
snowflakeConfiguration = mockk(relaxed = true)
snowflakeColumnUtils = mockk(relaxed = true)
snowflakeRecordFormatter = SnowflakeSchemaRecordFormatter()
columnManager =
mockk(relaxed = true) {
every { getMetaColumns() } returns
linkedMapOf(
"_AIRBYTE_RAW_ID" to ColumnType("VARCHAR", false),
"_AIRBYTE_EXTRACTED_AT" to ColumnType("TIMESTAMP_TZ", false),
"_AIRBYTE_META" to ColumnType("VARIANT", false),
"_AIRBYTE_GENERATION_ID" to ColumnType("NUMBER", true)
)
every { getTableColumnNames(any()) } returns
listOf(
"_AIRBYTE_RAW_ID",
"_AIRBYTE_EXTRACTED_AT",
"_AIRBYTE_META",
"_AIRBYTE_GENERATION_ID",
"columnName"
)
}
}
@Test
fun testAccumulate() {
val tableName = mockk<TableName>(relaxed = true)
val tableName = TableName(namespace = "test", name = "table")
val column = "columnName"
val columns = linkedMapOf(column to "NUMBER(38,0)")
columnSchema =
ColumnSchema(
inputToFinalColumnNames = mapOf(column to column.uppercase()),
finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
inputSchema = mapOf(column to FieldType(StringType, nullable = true))
)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord(column)
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = columns,
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
columnSchema = columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
)
buffer.accumulate(record)
assertEquals(true, buffer.csvFilePath?.exists())
assertEquals(0, buffer.recordCount)
runBlocking { buffer.accumulate(record) }
assertEquals(1, buffer.recordCount)
}
@Test
fun testAccumulateRaw() {
val tableName = mockk<TableName>(relaxed = true)
fun testFlushToStaging() {
val tableName = TableName(namespace = "test", name = "table")
val column = "columnName"
val columns = linkedMapOf(column to "NUMBER(38,0)")
columnSchema =
ColumnSchema(
inputToFinalColumnNames = mapOf(column to column.uppercase()),
finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
inputSchema = mapOf(column to FieldType(StringType, nullable = true))
)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord(column)
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = columns,
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
columnSchema = columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
)
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
buffer.accumulate(record)
assertEquals(true, buffer.csvFilePath?.exists())
assertEquals(1, buffer.recordCount)
}
@Test
fun testFlush() {
val tableName = mockk<TableName>(relaxed = true)
val column = "columnName"
val columns = linkedMapOf(column to "NUMBER(38,0)")
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord(column)
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = columns,
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
val expectedColumnNames =
listOf(
"_AIRBYTE_RAW_ID",
"_AIRBYTE_EXTRACTED_AT",
"_AIRBYTE_META",
"_AIRBYTE_GENERATION_ID",
"columnName"
)
runBlocking {
buffer.accumulate(record)
buffer.flush()
}
coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
coVerify(exactly = 1) {
snowflakeAirbyteClient.copyFromStage(
tableName,
match { it.endsWith("$CSV_FILE_EXTENSION$FILE_SUFFIX") }
)
coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
coVerify(exactly = 1) {
snowflakeAirbyteClient.copyFromStage(tableName, any(), expectedColumnNames)
}
}
}
@Test
fun testFlushRaw() {
val tableName = mockk<TableName>(relaxed = true)
fun testFlushToNoStaging() {
val tableName = TableName(namespace = "test", name = "table")
val column = "columnName"
val columns = linkedMapOf(column to "NUMBER(38,0)")
columnSchema =
ColumnSchema(
inputToFinalColumnNames = mapOf(column to column.uppercase()),
finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
inputSchema = mapOf(column to FieldType(StringType, nullable = true))
)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val record = createRecord(column)
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
columnSchema = columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
flushLimit = 1,
)
val expectedColumnNames =
listOf(
"_AIRBYTE_RAW_ID",
"_AIRBYTE_EXTRACTED_AT",
"_AIRBYTE_META",
"_AIRBYTE_GENERATION_ID",
"columnName"
)
runBlocking {
buffer.accumulate(record)
buffer.flush()
// In legacy raw mode, it still uses staging
coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
coVerify(exactly = 1) {
snowflakeAirbyteClient.copyFromStage(tableName, any(), expectedColumnNames)
}
}
}
@Test
fun testFileCreation() {
val tableName = TableName(namespace = "test", name = "table")
val column = "columnName"
columnSchema =
ColumnSchema(
inputToFinalColumnNames = mapOf(column to column.uppercase()),
finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
inputSchema = mapOf(column to FieldType(StringType, nullable = true))
)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord(column)
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = columns,
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
columnSchema = columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
)
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
runBlocking {
buffer.accumulate(record)
buffer.flush()
}
coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
coVerify(exactly = 1) {
snowflakeAirbyteClient.copyFromStage(
tableName,
match { it.endsWith("$CSV_FILE_EXTENSION$FILE_SUFFIX") }
)
}
}
@Test
fun testMissingFields() {
val tableName = mockk<TableName>(relaxed = true)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord("COLUMN1")
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = linkedMapOf("COLUMN1" to "NUMBER(38,0)", "COLUMN2" to "NUMBER(38,0)"),
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
)
runBlocking {
buffer.accumulate(record)
buffer.csvWriter?.flush()
// The csvFilePath is internal, so the test can access it directly
val filepath = buffer.csvFilePath
assertNotNull(filepath)
val file = filepath!!.toFile()
assert(file.exists())
// Close the writer to ensure all data is flushed
buffer.csvWriter?.close()
assertEquals(
"test-value$CSV_FIELD_SEPARATOR$CSV_LINE_DELIMITER",
readFromCsvFile(buffer.csvFilePath!!.toFile())
)
val lines = mutableListOf<String>()
GZIPInputStream(file.inputStream()).use { gzip ->
BufferedReader(InputStreamReader(gzip)).use { bufferedReader ->
bufferedReader.forEachLine { line -> lines.add(line) }
}
}
assertEquals(1, lines.size)
file.delete()
}
}
@Test
fun testMissingFieldsRaw() {
val tableName = mockk<TableName>(relaxed = true)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord("COLUMN1")
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = linkedMapOf("COLUMN1" to "NUMBER(38,0)", "COLUMN2" to "NUMBER(38,0)"),
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
)
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
runBlocking {
buffer.accumulate(record)
buffer.csvWriter?.flush()
buffer.csvWriter?.close()
assertEquals(
"test-value$CSV_FIELD_SEPARATOR$CSV_LINE_DELIMITER",
readFromCsvFile(buffer.csvFilePath!!.toFile())
)
}
}
private fun readFromCsvFile(file: File) =
GZIPInputStream(file.inputStream()).use { input ->
val reader = BufferedReader(InputStreamReader(input))
reader.readText()
}
private fun createRecord(columnName: String) =
mapOf(
columnName to AirbyteValue.from("test-value"),
Meta.COLUMN_NAME_AB_EXTRACTED_AT to IntegerValue(System.currentTimeMillis()),
Meta.COLUMN_NAME_AB_RAW_ID to StringValue("raw-id"),
Meta.COLUMN_NAME_AB_GENERATION_ID to IntegerValue(1223),
Meta.COLUMN_NAME_AB_META to StringValue("{\"changes\":[],\"syncId\":43}"),
"${columnName}Null" to NullValue
private fun createRecord(column: String): Map<String, AirbyteValue> {
return mapOf(
column to IntegerValue(value = 42),
Meta.COLUMN_NAME_AB_GENERATION_ID to NullValue,
Meta.COLUMN_NAME_AB_RAW_ID to StringValue("raw-id-1"),
Meta.COLUMN_NAME_AB_EXTRACTED_AT to IntegerValue(1234567890),
Meta.COLUMN_NAME_AB_META to StringValue("meta-data-foo"),
)
}
}
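
Taken together, these tests describe the buffer's lifecycle: each record is formatted into a CSV row, rows accumulate in a gzipped file until flushLimit is reached, and flush() stages the file and copies it into the table. A self-contained pseudo-flow; the interface below is a stand-in whose shape is inferred from the putInStage/copyFromStage coVerify calls, not the connector's real client:

interface StagingClientSketch {
    suspend fun putInStage(table: String, localGzipCsvPath: String)
    suspend fun copyFromStage(table: String, stagedFileName: String, columnNames: List<String>)
}

suspend fun flushSketch(
    client: StagingClientSketch,
    table: String,
    csvPath: String,
    columnNames: List<String>, // meta columns first, then user columns (see expectedColumnNames)
) {
    client.putInStage(table, csvPath)                 // PUT the gzipped CSV into the table's stage
    client.copyFromStage(table, csvPath, columnNames) // COPY INTO the table from the stage
}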

View File

@@ -15,12 +15,8 @@ import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.mockk.every
import io.mockk.mockk
import kotlin.collections.plus
import io.airbyte.cdk.load.schema.model.ColumnSchema
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
private val AIRBYTE_COLUMN_TYPES_MAP =
@@ -58,28 +54,16 @@ private fun createExpected(
internal class SnowflakeRawRecordFormatterTest {
private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
@BeforeEach
fun setup() {
snowflakeColumnUtils = mockk {
every { getFormattedDefaultColumnNames(any()) } returns
AIRBYTE_COLUMN_TYPES_MAP.keys.toList()
}
}
@Test
fun testFormatting() {
val columnName = "test-column-name"
val columnValue = "test-column-value"
val columns = AIRBYTE_COLUMN_TYPES_MAP
val record = createRecord(columnName = columnName, columnValue = columnValue)
val formatter =
SnowflakeRawRecordFormatter(
columns = AIRBYTE_COLUMN_TYPES_MAP,
snowflakeColumnUtils = snowflakeColumnUtils
)
val formattedValue = formatter.format(record)
val formatter = SnowflakeRawRecordFormatter()
// RawRecordFormatter doesn't use columnSchema, but the interface still requires one
val dummyColumnSchema = ColumnSchema(emptyMap(), emptyMap(), emptyMap())
val formattedValue = formatter.format(record, dummyColumnSchema)
val expectedValue =
createExpected(
record = record,
@@ -93,33 +77,28 @@ internal class SnowflakeRawRecordFormatterTest {
fun testFormattingMigratedFromPreviousVersion() {
val columnName = "test-column-name"
val columnValue = "test-column-value"
val columnsMap =
linkedMapOf(
COLUMN_NAME_AB_EXTRACTED_AT to "TIMESTAMP_TZ(9)",
COLUMN_NAME_AB_LOADED_AT to "TIMESTAMP_TZ(9)",
COLUMN_NAME_AB_META to "VARIANT",
COLUMN_NAME_DATA to "VARIANT",
COLUMN_NAME_AB_RAW_ID to "VARCHAR(16777216)",
COLUMN_NAME_AB_GENERATION_ID to "NUMBER(38,0)",
)
val record = createRecord(columnName = columnName, columnValue = columnValue)
val formatter =
SnowflakeRawRecordFormatter(
columns = columnsMap,
snowflakeColumnUtils = snowflakeColumnUtils
)
val formattedValue = formatter.format(record)
val formatter = SnowflakeRawRecordFormatter()
// RawRecordFormatter doesn't use columnSchema, but the interface still requires one
val dummyColumnSchema = ColumnSchema(emptyMap(), emptyMap(), emptyMap())
val formattedValue = formatter.format(record, dummyColumnSchema)
// The formatter outputs in a fixed order regardless of input column order:
// 1. AB_RAW_ID
// 2. AB_EXTRACTED_AT
// 3. AB_META
// 4. AB_GENERATION_ID
// 5. AB_LOADED_AT
// 6. DATA (JSON with remaining columns)
val expectedValue =
createExpected(
record = record,
columns = columnsMap,
airbyteColumns = columnsMap.keys.toList(),
)
.toMutableList()
expectedValue.add(
columnsMap.keys.indexOf(COLUMN_NAME_DATA),
"{\"$columnName\":\"$columnValue\"}"
)
listOf(
record[COLUMN_NAME_AB_RAW_ID]!!.toCsvValue(),
record[COLUMN_NAME_AB_EXTRACTED_AT]!!.toCsvValue(),
record[COLUMN_NAME_AB_META]!!.toCsvValue(),
record[COLUMN_NAME_AB_GENERATION_ID]!!.toCsvValue(),
record[COLUMN_NAME_AB_LOADED_AT]!!.toCsvValue(),
"{\"$columnName\":\"$columnValue\"}"
)
assertEquals(expectedValue, formattedValue)
}
}
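
The comment block in testFormattingMigratedFromPreviousVersion spells out the raw formatter's fixed output order. Restated as a sketch, with the lowercase meta-column names assumed to match the Meta.COLUMN_NAME_* constants and the final element being the user columns serialized to JSON:

fun rawRowSketch(record: Map<String, Any?>, dataJson: String): List<Any?> =
    listOf(
        record["_airbyte_raw_id"],
        record["_airbyte_extracted_at"],
        record["_airbyte_meta"],
        record["_airbyte_generation_id"],
        record["_airbyte_loaded_at"],
        dataJson, // remaining user columns, e.g. {"test-column-name":"test-column-value"}
    )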

View File

@@ -4,66 +4,58 @@
package io.airbyte.integrations.destination.snowflake.write.load
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerValue
import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.data.csv.toCsvValue
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.mockk.every
import io.mockk.mockk
import java.util.AbstractMap
import kotlin.collections.component1
import kotlin.collections.component2
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
private val AIRBYTE_COLUMN_TYPES_MAP =
linkedMapOf(
COLUMN_NAME_AB_RAW_ID to "VARCHAR(16777216)",
COLUMN_NAME_AB_EXTRACTED_AT to "TIMESTAMP_TZ(9)",
COLUMN_NAME_AB_META to "VARIANT",
COLUMN_NAME_AB_GENERATION_ID to "NUMBER(38,0)",
)
.mapKeys { it.key.toSnowflakeCompatibleName() }
internal class SnowflakeSchemaRecordFormatterTest {
private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
private fun createColumnSchema(userColumns: Map<String, String>): ColumnSchema {
val finalSchema = linkedMapOf<String, ColumnType>()
val inputToFinalColumnNames = mutableMapOf<String, String>()
val inputSchema = mutableMapOf<String, FieldType>()
@BeforeEach
fun setup() {
snowflakeColumnUtils = mockk {
every { getFormattedDefaultColumnNames(any()) } returns
AIRBYTE_COLUMN_TYPES_MAP.keys.toList()
// Add user columns
userColumns.forEach { (name, type) ->
val finalName = name.toSnowflakeCompatibleName()
finalSchema[finalName] = ColumnType(type, true)
inputToFinalColumnNames[name] = finalName
inputSchema[name] = FieldType(StringType, nullable = true)
}
return ColumnSchema(
inputToFinalColumnNames = inputToFinalColumnNames,
finalSchema = finalSchema,
inputSchema = inputSchema
)
}
@Test
fun testFormatting() {
val columnName = "test-column-name"
val columnValue = "test-column-value"
val columns =
(AIRBYTE_COLUMN_TYPES_MAP + linkedMapOf(columnName to "VARCHAR(16777216)")).mapKeys {
it.key.toSnowflakeCompatibleName()
}
val userColumns = mapOf(columnName to "VARCHAR(16777216)")
val columnSchema = createColumnSchema(userColumns)
val record = createRecord(columnName, columnValue)
val formatter =
SnowflakeSchemaRecordFormatter(
columns = columns as LinkedHashMap<String, String>,
snowflakeColumnUtils = snowflakeColumnUtils
)
val formattedValue = formatter.format(record)
val formatter = SnowflakeSchemaRecordFormatter()
val formattedValue = formatter.format(record, columnSchema)
val expectedValue =
createExpected(
record = record,
columns = columns,
columnSchema = columnSchema,
)
assertEquals(expectedValue, formattedValue)
}
@@ -72,21 +64,15 @@ internal class SnowflakeSchemaRecordFormatterTest {
fun testFormattingVariant() {
val columnName = "test-column-name"
val columnValue = "{\"test\": \"test-value\"}"
val columns =
(AIRBYTE_COLUMN_TYPES_MAP + linkedMapOf(columnName to "VARIANT")).mapKeys {
it.key.toSnowflakeCompatibleName()
}
val userColumns = mapOf(columnName to "VARIANT")
val columnSchema = createColumnSchema(userColumns)
val record = createRecord(columnName, columnValue)
val formatter =
SnowflakeSchemaRecordFormatter(
columns = columns as LinkedHashMap<String, String>,
snowflakeColumnUtils = snowflakeColumnUtils
)
val formattedValue = formatter.format(record)
val formatter = SnowflakeSchemaRecordFormatter()
val formattedValue = formatter.format(record, columnSchema)
val expectedValue =
createExpected(
record = record,
columns = columns,
columnSchema = columnSchema,
)
assertEquals(expectedValue, formattedValue)
}
@@ -95,23 +81,16 @@ internal class SnowflakeSchemaRecordFormatterTest {
fun testFormattingMissingColumn() {
val columnName = "test-column-name"
val columnValue = "test-column-value"
val columns =
AIRBYTE_COLUMN_TYPES_MAP +
linkedMapOf(
columnName to "VARCHAR(16777216)",
"missing-column" to "VARCHAR(16777216)"
)
val userColumns =
mapOf(columnName to "VARCHAR(16777216)", "missing-column" to "VARCHAR(16777216)")
val columnSchema = createColumnSchema(userColumns)
val record = createRecord(columnName, columnValue)
val formatter =
SnowflakeSchemaRecordFormatter(
columns = columns as LinkedHashMap<String, String>,
snowflakeColumnUtils = snowflakeColumnUtils
)
val formattedValue = formatter.format(record)
val formatter = SnowflakeSchemaRecordFormatter()
val formattedValue = formatter.format(record, columnSchema)
val expectedValue =
createExpected(
record = record,
columns = columns,
columnSchema = columnSchema,
filterMissing = false,
)
assertEquals(expectedValue, formattedValue)
@@ -128,16 +107,37 @@ internal class SnowflakeSchemaRecordFormatterTest {
private fun createExpected(
record: Map<String, AirbyteValue>,
columns: Map<String, String>,
columnSchema: ColumnSchema,
filterMissing: Boolean = true,
) =
record.entries
.associate { entry -> entry.key.toSnowflakeCompatibleName() to entry.value }
.map { entry -> AbstractMap.SimpleEntry(entry.key, entry.value.toCsvValue()) }
.sortedBy { entry ->
if (columns.keys.indexOf(entry.key) > -1) columns.keys.indexOf(entry.key)
else Int.MAX_VALUE
): List<Any> {
val columns = columnSchema.finalSchema.keys.toList()
val result = mutableListOf<Any>()
// Add meta columns first in the expected order
result.add(record[COLUMN_NAME_AB_RAW_ID]?.toCsvValue() ?: "")
result.add(record[COLUMN_NAME_AB_EXTRACTED_AT]?.toCsvValue() ?: "")
result.add(record[COLUMN_NAME_AB_META]?.toCsvValue() ?: "")
result.add(record[COLUMN_NAME_AB_GENERATION_ID]?.toCsvValue() ?: "")
// Add user columns
val userColumns =
columns.filterNot { col ->
listOf(
COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName(),
COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName(),
COLUMN_NAME_AB_META.toSnowflakeCompatibleName(),
COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()
)
.contains(col)
}
.filter { (k, _) -> if (filterMissing) columns.contains(k) else true }
.map { it.value }
userColumns.forEach { columnName ->
val value = record[columnName] ?: if (!filterMissing) NullValue else null
if (value != null || !filterMissing) {
result.add(value?.toCsvValue() ?: "")
}
}
return result
}
}
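
For contrast with the raw formatter, createExpected above shows the schema formatter's row shape: the four Airbyte meta columns lead, then user columns follow in finalSchema order, with missing values passed through as nulls. A sketch under those assumptions (column-name literals are again assumed to match the Meta constants):

fun schemaRowSketch(
    record: Map<String, Any?>,
    userColumnsInFinalOrder: List<String>, // finalSchema keys minus the meta columns
): List<Any?> =
    listOf(
        record["_airbyte_raw_id"],
        record["_airbyte_extracted_at"],
        record["_airbyte_meta"],
        record["_airbyte_generation_id"],
    ) + userColumnsInFinalOrder.map { record[it] }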

View File

@@ -1,45 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.write.transform
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
internal class SnowflakeColumnNameMapperTest {
@Test
fun testGetMappedColumnName() {
val columnName = "tést-column-name"
val expectedName = "test-column-name"
val stream = mockk<DestinationStream>()
val tableCatalog = mockk<TableCatalog>()
val snowflakeConfiguration = mockk<SnowflakeConfiguration>(relaxed = true)
// Configure the mock to return the expected mapped column name
every { tableCatalog.getMappedColumnName(stream, columnName) } returns expectedName
val mapper = SnowflakeColumnNameMapper(tableCatalog, snowflakeConfiguration)
val result = mapper.getMappedColumnName(stream = stream, columnName = columnName)
assertEquals(expectedName, result)
}
@Test
fun testGetMappedColumnNameRawFormat() {
val columnName = "tést-column-name"
val stream = mockk<DestinationStream>()
val tableCatalog = mockk<TableCatalog>()
val snowflakeConfiguration =
mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns true }
val mapper = SnowflakeColumnNameMapper(tableCatalog, snowflakeConfiguration)
val result = mapper.getMappedColumnName(stream = stream, columnName = columnName)
assertEquals(columnName, result)
}
}

View File

@@ -260,6 +260,7 @@ desired namespace.
| Version | Date | Pull Request | Subject |
|:----------------|:-----------|:--------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 4.0.32-rc.1 | 2025-12-18 | [70903](https://github.com/airbytehq/airbyte/pull/70903) | Upgrade to CDK 0.1.91; internal refactoring |
| 4.0.31 | 2025-12-08 | [70442](https://github.com/airbytehq/airbyte/pull/70442) | Write VARIANT values correctly when underlying Airbyte type is a `union` |
| 4.0.30 | 2025-11-24 | [69842](https://github.com/airbytehq/airbyte/pull/69842) | Update documentation about numeric value handling |
| 4.0.29 | 2025-11-14 | [69342](https://github.com/airbytehq/airbyte/pull/69342) | Truncate NumberValues and IntegerValues with excessive precision instead of nullifying them |