Snowflake maps its schema once at the start. (#70903)
@@ -1,3 +1,3 @@
 testExecutionConcurrency=-1
-cdkVersion=0.1.82
+cdkVersion=0.1.91
 JunitMethodExecutionTimeout=10m
@@ -6,7 +6,7 @@ data:
   connectorSubtype: database
   connectorType: destination
   definitionId: 424892c4-daac-4491-b35d-c6688ba547ba
-  dockerImageTag: 4.0.31
+  dockerImageTag: 4.0.32-rc.1
   dockerRepository: airbyte/destination-snowflake
   documentationUrl: https://docs.airbyte.com/integrations/destinations/snowflake
   githubIssueLabel: destination-snowflake
@@ -31,6 +31,8 @@ data:
       enabled: true
   releaseStage: generally_available
   releases:
+    rolloutConfiguration:
+      enableProgressiveRollout: true
     breakingChanges:
       2.0.0:
         message: Remove GCS/S3 loading method support.
@@ -11,15 +11,18 @@ import io.airbyte.cdk.load.check.CheckOperationV2
 import io.airbyte.cdk.load.check.DestinationCheckerV2
 import io.airbyte.cdk.load.config.DataChannelMedium
 import io.airbyte.cdk.load.dataflow.config.AggregatePublishingConfig
-import io.airbyte.cdk.load.orchestration.db.DefaultTempTableNameGenerator
-import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
+import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator
+import io.airbyte.cdk.load.table.TempTableNameGenerator
 import io.airbyte.cdk.output.OutputConsumer
 import io.airbyte.integrations.destination.snowflake.cdk.SnowflakeMigratingConfigurationSpecificationSupplier
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
 import io.airbyte.integrations.destination.snowflake.spec.KeyPairAuthConfiguration
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfigurationFactory
 import io.airbyte.integrations.destination.snowflake.spec.UsernamePasswordAuthConfiguration
+import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRawRecordFormatter
+import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
+import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter
 import io.micronaut.context.annotation.Factory
 import io.micronaut.context.annotation.Primary
 import io.micronaut.context.annotation.Requires
@@ -204,6 +207,17 @@ class SnowflakeBeanFactory {
         outputConsumer: OutputConsumer,
     ) = CheckOperationV2(destinationChecker, outputConsumer)

+    @Singleton
+    fun snowflakeRecordFormatter(
+        snowflakeConfiguration: SnowflakeConfiguration
+    ): SnowflakeRecordFormatter {
+        return if (snowflakeConfiguration.legacyRawTablesOnly) {
+            SnowflakeRawRecordFormatter()
+        } else {
+            SnowflakeSchemaRecordFormatter()
+        }
+    }
+
     @Singleton
     fun aggregatePublishingConfig(dataChannelMedium: DataChannelMedium): AggregatePublishingConfig {
         // NOT speed mode
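
The new snowflakeRecordFormatter bean above picks a formatter implementation off the legacyRawTablesOnly flag. A minimal, self-contained sketch of that selection pattern; the formatter types below are simplified stand-ins, not the connector's actual classes:

// Stand-in interface for the sketch; not the CDK's SnowflakeRecordFormatter.
interface RecordFormatter {
    fun format(record: Map<String, Any?>): List<Any?>
}

// Raw mode: the record travels as a single blob alongside the meta columns.
class RawFormatter : RecordFormatter {
    override fun format(record: Map<String, Any?>) = listOf(record.toString())
}

// Schema mode: each user field becomes its own column value.
class SchemaFormatter : RecordFormatter {
    override fun format(record: Map<String, Any?>) = record.values.toList()
}

// Mirrors the bean factory's branch on legacyRawTablesOnly.
fun formatterFor(legacyRawTablesOnly: Boolean): RecordFormatter =
    if (legacyRawTablesOnly) RawFormatter() else SchemaFormatter()

fun main() {
    val record = mapOf("id" to 1, "name" to "ada")
    println(formatterFor(true).format(record))  // one raw blob
    println(formatterFor(false).format(record)) // one value per column
}
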
@@ -13,13 +13,17 @@ import io.airbyte.cdk.load.data.FieldType
 import io.airbyte.cdk.load.data.ObjectType
 import io.airbyte.cdk.load.data.StringType
 import io.airbyte.cdk.load.message.Meta
+import io.airbyte.cdk.load.schema.model.ColumnSchema
+import io.airbyte.cdk.load.schema.model.StreamTableSchema
+import io.airbyte.cdk.load.schema.model.TableName
+import io.airbyte.cdk.load.schema.model.TableNames
 import io.airbyte.cdk.load.table.ColumnNameMapping
-import io.airbyte.cdk.load.table.TableName
 import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
 import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
+import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter
 import jakarta.inject.Singleton
 import java.time.OffsetDateTime
 import java.util.UUID
@@ -31,7 +35,7 @@ internal const val CHECK_COLUMN_NAME = "test_key"
|
|||||||
class SnowflakeChecker(
|
class SnowflakeChecker(
|
||||||
private val snowflakeAirbyteClient: SnowflakeAirbyteClient,
|
private val snowflakeAirbyteClient: SnowflakeAirbyteClient,
|
||||||
private val snowflakeConfiguration: SnowflakeConfiguration,
|
private val snowflakeConfiguration: SnowflakeConfiguration,
|
||||||
private val snowflakeColumnUtils: SnowflakeColumnUtils,
|
private val columnManager: SnowflakeColumnManager,
|
||||||
) : DestinationCheckerV2 {
|
) : DestinationCheckerV2 {
|
||||||
|
|
||||||
override fun check() {
|
override fun check() {
|
||||||
@@ -46,11 +50,40 @@ class SnowflakeChecker(
                 Meta.AirbyteMetaFields.GENERATION_ID.fieldName to AirbyteValue.from(0),
                 CHECK_COLUMN_NAME.toSnowflakeCompatibleName() to AirbyteValue.from("test-value")
             )
-        val outputSchema = snowflakeConfiguration.schema.toSnowflakeCompatibleName()
+        val outputSchema =
+            if (snowflakeConfiguration.legacyRawTablesOnly) {
+                snowflakeConfiguration.schema
+            } else {
+                snowflakeConfiguration.schema.toSnowflakeCompatibleName()
+            }
         val tableName =
             "_airbyte_connection_test_${
             UUID.randomUUID().toString().replace("-".toRegex(), "")}".toSnowflakeCompatibleName()
         val qualifiedTableName = TableName(namespace = outputSchema, name = tableName)
+        val tableSchema =
+            StreamTableSchema(
+                tableNames =
+                    TableNames(
+                        finalTableName = qualifiedTableName,
+                        tempTableName = qualifiedTableName
+                    ),
+                columnSchema =
+                    ColumnSchema(
+                        inputToFinalColumnNames =
+                            mapOf(
+                                CHECK_COLUMN_NAME to CHECK_COLUMN_NAME.toSnowflakeCompatibleName()
+                            ),
+                        finalSchema =
+                            mapOf(
+                                CHECK_COLUMN_NAME.toSnowflakeCompatibleName() to
+                                    io.airbyte.cdk.load.component.ColumnType("VARCHAR", false)
+                            ),
+                        inputSchema =
+                            mapOf(CHECK_COLUMN_NAME to FieldType(StringType, nullable = false))
+                    ),
+                importType = Append
+            )
+
         val destinationStream =
             DestinationStream(
                 unmappedNamespace = outputSchema,
@@ -63,7 +96,8 @@ class SnowflakeChecker(
                 generationId = 0L,
                 minimumGenerationId = 0L,
                 syncId = 0L,
-                namespaceMapper = NamespaceMapper()
+                namespaceMapper = NamespaceMapper(),
+                tableSchema = tableSchema
             )
         runBlocking {
             try {
@@ -75,14 +109,14 @@ class SnowflakeChecker(
                     replace = true,
                 )

-                val columns = snowflakeAirbyteClient.describeTable(qualifiedTableName)
                 val snowflakeInsertBuffer =
                     SnowflakeInsertBuffer(
                         tableName = qualifiedTableName,
-                        columns = columns,
                         snowflakeClient = snowflakeAirbyteClient,
                         snowflakeConfiguration = snowflakeConfiguration,
-                        snowflakeColumnUtils = snowflakeColumnUtils,
+                        columnSchema = tableSchema.columnSchema,
+                        columnManager = columnManager,
+                        snowflakeRecordFormatter = SnowflakeSchemaRecordFormatter(),
                     )

                 snowflakeInsertBuffer.accumulate(data)
@@ -13,18 +13,16 @@ import io.airbyte.cdk.load.component.TableColumns
 import io.airbyte.cdk.load.component.TableOperationsClient
 import io.airbyte.cdk.load.component.TableSchema
 import io.airbyte.cdk.load.component.TableSchemaEvolutionClient
+import io.airbyte.cdk.load.schema.model.TableName
 import io.airbyte.cdk.load.table.ColumnNameMapping
-import io.airbyte.cdk.load.table.TableName
 import io.airbyte.cdk.load.util.deserializeToNode
-import io.airbyte.integrations.destination.snowflake.db.ColumnDefinition
-import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
 import io.airbyte.integrations.destination.snowflake.sql.COUNT_TOTAL_ALIAS
-import io.airbyte.integrations.destination.snowflake.sql.NOT_NULL
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
 import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
 import io.airbyte.integrations.destination.snowflake.sql.andLog
+import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
 import io.github.oshai.kotlinlogging.KotlinLogging
 import jakarta.inject.Singleton
 import java.sql.ResultSet
@@ -41,13 +39,10 @@ private val log = KotlinLogging.logger {}
 class SnowflakeAirbyteClient(
     private val dataSource: DataSource,
     private val sqlGenerator: SnowflakeDirectLoadSqlGenerator,
-    private val snowflakeColumnUtils: SnowflakeColumnUtils,
     private val snowflakeConfiguration: SnowflakeConfiguration,
+    private val columnManager: SnowflakeColumnManager,
 ) : TableOperationsClient, TableSchemaEvolutionClient {

-    private val airbyteColumnNames =
-        snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet()
-
     override suspend fun countTable(tableName: TableName): Long? =
         try {
             dataSource.connection.use { connection ->
@@ -126,7 +121,7 @@ class SnowflakeAirbyteClient(
         columnNameMapping: ColumnNameMapping,
         replace: Boolean
     ) {
-        execute(sqlGenerator.createTable(stream, tableName, columnNameMapping, replace))
+        execute(sqlGenerator.createTable(tableName, stream.tableSchema, replace))
         execute(sqlGenerator.createSnowflakeStage(tableName))
     }

@@ -163,7 +158,15 @@ class SnowflakeAirbyteClient(
         sourceTableName: TableName,
         targetTableName: TableName
     ) {
-        execute(sqlGenerator.copyTable(columnNameMapping, sourceTableName, targetTableName))
+        // Get all column names from the mapping (both meta columns and user columns)
+        val columnNames = buildSet {
+            // Add Airbyte meta columns (using uppercase constants)
+            addAll(columnManager.getMetaColumnNames())
+            // Add user columns from mapping
+            addAll(columnNameMapping.values)
+        }
+
+        execute(sqlGenerator.copyTable(columnNames, sourceTableName, targetTableName))
     }

     override suspend fun upsertTable(
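
With copyTable now receiving an explicit column set, the generated SQL names every column in both the INSERT list and the SELECT. A runnable sketch of that statement's shape, with an illustrative quote helper and invented table names rather than the connector's utilities:

// Illustrative quoting; the connector uses its own String.quote() extension.
fun quote(name: String) = "\"$name\""

fun copyTableSql(columnNames: Set<String>, source: String, target: String): String {
    val columnList = columnNames.joinToString(", ") { quote(it) }
    return """
        |INSERT INTO $target ($columnList)
        |SELECT $columnList
        |FROM $source
    """.trimMargin()
}

fun main() {
    // Meta columns first, then user columns, mirroring the buildSet above.
    val columns = linkedSetOf("_AIRBYTE_RAW_ID", "_AIRBYTE_EXTRACTED_AT", "ID", "NAME")
    println(copyTableSql(columns, "DB.SCHEMA.TMP_USERS", "DB.SCHEMA.USERS"))
}
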
@@ -172,9 +175,7 @@ class SnowflakeAirbyteClient(
         sourceTableName: TableName,
         targetTableName: TableName
     ) {
-        execute(
-            sqlGenerator.upsertTable(stream, columnNameMapping, sourceTableName, targetTableName)
-        )
+        execute(sqlGenerator.upsertTable(stream.tableSchema, sourceTableName, targetTableName))
     }

     override suspend fun dropTable(tableName: TableName) {
@@ -206,7 +207,7 @@ class SnowflakeAirbyteClient(
         stream: DestinationStream,
         columnNameMapping: ColumnNameMapping
     ): TableSchema {
-        return TableSchema(getColumnsFromStream(stream, columnNameMapping))
+        return TableSchema(stream.tableSchema.columnSchema.finalSchema)
     }

     override suspend fun applyChangeset(
@@ -253,7 +254,7 @@ class SnowflakeAirbyteClient(
                 val columnName = escapeJsonIdentifier(rs.getString("name"))

                 // Filter out airbyte columns
-                if (airbyteColumnNames.contains(columnName)) {
+                if (columnManager.getMetaColumnNames().contains(columnName)) {
                     continue
                 }
                 val dataType = rs.getString("type").takeWhile { char -> char != '(' }
@@ -271,49 +272,6 @@ class SnowflakeAirbyteClient(
             }
         }

-    internal fun getColumnsFromStream(
-        stream: DestinationStream,
-        columnNameMapping: ColumnNameMapping
-    ): Map<String, ColumnType> =
-        snowflakeColumnUtils
-            .columnsAndTypes(stream.schema.asColumns(), columnNameMapping)
-            .filter { column -> column.columnName !in airbyteColumnNames }
-            .associate { column ->
-                // columnsAndTypes returns types as either `FOO` or `FOO NOT NULL`.
-                // so check for that suffix.
-                val nullable = !column.columnType.endsWith(NOT_NULL)
-                val type =
-                    column.columnType
-                        .takeWhile { char ->
-                            // This is to remove any precision parts of the dialect type
-                            char != '('
-                        }
-                        .removeSuffix(NOT_NULL)
-                        .trim()
-
-                column.columnName to ColumnType(type, nullable)
-            }
-
-    internal fun generateSchemaChanges(
-        columnsInDb: Set<ColumnDefinition>,
-        columnsInStream: Set<ColumnDefinition>
-    ): Triple<Set<ColumnDefinition>, Set<ColumnDefinition>, Set<ColumnDefinition>> {
-        val addedColumns =
-            columnsInStream.filter { it.name !in columnsInDb.map { col -> col.name } }.toSet()
-        val deletedColumns =
-            columnsInDb.filter { it.name !in columnsInStream.map { col -> col.name } }.toSet()
-        val commonColumns =
-            columnsInStream.filter { it.name in columnsInDb.map { col -> col.name } }.toSet()
-        val modifiedColumns =
-            commonColumns
-                .filter {
-                    val dbType = columnsInDb.find { column -> it.name == column.name }?.type
-                    it.type != dbType
-                }
-                .toSet()
-        return Triple(addedColumns, deletedColumns, modifiedColumns)
-    }
-
     override suspend fun getGenerationId(tableName: TableName): Long =
         try {
             dataSource.connection.use { connection ->
@@ -326,7 +284,7 @@ class SnowflakeAirbyteClient(
                      * format. In order to make sure these strings will match any column names
                      * that we have formatted in-memory, re-apply the escaping.
                      */
-                    resultSet.getLong(snowflakeColumnUtils.getGenerationIdColumnName())
+                    resultSet.getLong(columnManager.getGenerationIdColumnName())
                 } else {
                     log.warn {
                         "No generation ID found for table ${tableName.toPrettyString()}, returning 0"
@@ -351,8 +309,8 @@ class SnowflakeAirbyteClient(
         execute(sqlGenerator.putInStage(tableName, tempFilePath))
     }

-    fun copyFromStage(tableName: TableName, filename: String) {
-        execute(sqlGenerator.copyFromStage(tableName, filename))
+    fun copyFromStage(tableName: TableName, filename: String, columnNames: List<String>) {
+        execute(sqlGenerator.copyFromStage(tableName, filename, columnNames))
     }

     fun describeTable(tableName: TableName): LinkedHashMap<String, String> =
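
copyFromStage now threads the column list through as well. The exact SQL it renders is outside this excerpt, but Snowflake's COPY INTO accepts an explicit column list, so the statement plausibly takes a shape like this hypothetical sketch (table, stage, and file names are invented):

// Hypothetical rendering of COPY INTO with a column list; not the generator's code.
fun copyFromStageSql(table: String, stage: String, file: String, columns: List<String>): String =
    "COPY INTO $table (${columns.joinToString(", ")}) FROM @$stage FILES = ('$file')"

fun main() {
    println(
        copyFromStageSql(
            "DB.SCHEMA.USERS",
            "AIRBYTE_STAGE",
            "batch.csv",
            listOf("\"_AIRBYTE_RAW_ID\"", "\"ID\"", "\"NAME\""),
        )
    )
}
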
@@ -4,47 +4,41 @@

 package io.airbyte.integrations.destination.snowflake.dataflow

+import io.airbyte.cdk.load.command.DestinationCatalog
 import io.airbyte.cdk.load.dataflow.aggregate.Aggregate
 import io.airbyte.cdk.load.dataflow.aggregate.AggregateFactory
 import io.airbyte.cdk.load.dataflow.aggregate.StoreKey
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig
-import io.airbyte.cdk.load.table.TableName
+import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
 import io.airbyte.cdk.load.write.StreamStateStore
 import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
 import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
-import io.micronaut.cache.annotation.CacheConfig
-import io.micronaut.cache.annotation.Cacheable
+import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
 import jakarta.inject.Singleton

 @Singleton
-@CacheConfig("table-columns")
-// class has to be open to make the cache stuff work
-open class SnowflakeAggregateFactory(
+class SnowflakeAggregateFactory(
     private val snowflakeClient: SnowflakeAirbyteClient,
     private val streamStateStore: StreamStateStore<DirectLoadTableExecutionConfig>,
     private val snowflakeConfiguration: SnowflakeConfiguration,
-    private val snowflakeColumnUtils: SnowflakeColumnUtils,
+    private val catalog: DestinationCatalog,
+    private val columnManager: SnowflakeColumnManager,
+    private val snowflakeRecordFormatter: SnowflakeRecordFormatter,
 ) : AggregateFactory {
     override fun create(key: StoreKey): Aggregate {
+        val stream = catalog.getStream(key)
         val tableName = streamStateStore.get(key)!!.tableName

         val buffer =
             SnowflakeInsertBuffer(
                 tableName = tableName,
-                columns = getTableColumns(tableName),
                 snowflakeClient = snowflakeClient,
                 snowflakeConfiguration = snowflakeConfiguration,
-                snowflakeColumnUtils = snowflakeColumnUtils,
+                columnSchema = stream.tableSchema.columnSchema,
+                columnManager = columnManager,
+                snowflakeRecordFormatter = snowflakeRecordFormatter,
             )
         return SnowflakeAggregate(buffer = buffer)
     }
-
-    // We assume that a table isn't getting altered _during_ a sync.
-    // This allows us to only SHOW COLUMNS once per table per sync,
-    // rather than refetching it on every aggregate.
-    @Cacheable
-    // function has to be open to make caching work
-    internal open fun getTableColumns(tableName: TableName) =
-        snowflakeClient.describeTable(tableName)
 }
@@ -1,13 +0,0 @@
-/*
- * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.snowflake.db
-
-/**
- * Jdbc destination column definition representation
- *
- * @param name
- * @param type
- */
-data class ColumnDefinition(val name: String, val type: String)
@@ -4,17 +4,17 @@

 package io.airbyte.integrations.destination.snowflake.db

+import io.airbyte.cdk.load.command.DestinationCatalog
 import io.airbyte.cdk.load.component.TableOperationsClient
-import io.airbyte.cdk.load.orchestration.db.BaseDirectLoadInitialStatusGatherer
-import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
+import io.airbyte.cdk.load.table.BaseDirectLoadInitialStatusGatherer
 import jakarta.inject.Singleton

 @Singleton
 class SnowflakeDirectLoadDatabaseInitialStatusGatherer(
     tableOperationsClient: TableOperationsClient,
-    tempTableNameGenerator: TempTableNameGenerator,
+    catalog: DestinationCatalog,
 ) :
     BaseDirectLoadInitialStatusGatherer(
         tableOperationsClient,
-        tempTableNameGenerator,
+        catalog,
     )
@@ -1,95 +0,0 @@
-/*
- * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.snowflake.db
-
-import io.airbyte.cdk.ConfigErrorException
-import io.airbyte.cdk.load.command.DestinationStream
-import io.airbyte.cdk.load.orchestration.db.ColumnNameGenerator
-import io.airbyte.cdk.load.orchestration.db.FinalTableNameGenerator
-import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TypingDedupingUtil
-import io.airbyte.cdk.load.table.TableName
-import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.airbyte.integrations.destination.snowflake.sql.QUOTE
-import jakarta.inject.Singleton
-
-@Singleton
-class SnowflakeFinalTableNameGenerator(private val config: SnowflakeConfiguration) :
-    FinalTableNameGenerator {
-    override fun getTableName(streamDescriptor: DestinationStream.Descriptor): TableName {
-        val namespace = streamDescriptor.namespace ?: config.schema
-        return if (!config.legacyRawTablesOnly) {
-            TableName(
-                namespace = namespace.toSnowflakeCompatibleName(),
-                name = streamDescriptor.name.toSnowflakeCompatibleName(),
-            )
-        } else {
-            TableName(
-                namespace = config.internalTableSchema,
-                name =
-                    TypingDedupingUtil.concatenateRawTableName(
-                        namespace = escapeJsonIdentifier(namespace),
-                        name = escapeJsonIdentifier(streamDescriptor.name),
-                    ),
-            )
-        }
-    }
-}
-
-@Singleton
-class SnowflakeColumnNameGenerator(private val config: SnowflakeConfiguration) :
-    ColumnNameGenerator {
-    override fun getColumnName(column: String): ColumnNameGenerator.ColumnName {
-        return if (!config.legacyRawTablesOnly) {
-            ColumnNameGenerator.ColumnName(
-                column.toSnowflakeCompatibleName(),
-                column.toSnowflakeCompatibleName(),
-            )
-        } else {
-            ColumnNameGenerator.ColumnName(
-                column,
-                column,
-            )
-        }
-    }
-}
-
-/**
- * Escapes double-quotes in a JSON identifier by doubling them. This shit is legacy -- I don't know
- * why this would be necessary but no harm in keeping it so I am keeping it.
- *
- * @return The escaped identifier.
- */
-fun escapeJsonIdentifier(identifier: String): String {
-    // Note that we don't need to escape backslashes here!
-    // The only special character in an identifier is the double-quote, which needs to be
-    // doubled.
-    return identifier.replace(QUOTE, "$QUOTE$QUOTE")
-}
-
-/**
- * Transforms a string to be compatible with Snowflake table and column names.
- *
- * @return The transformed string suitable for Snowflake identifiers.
- */
-fun String.toSnowflakeCompatibleName(): String {
-    var identifier = this
-
-    // Handle empty strings
-    if (identifier.isEmpty()) {
-        throw ConfigErrorException("Empty string is invalid identifier")
-    }
-
-    // Snowflake scripting language does something weird when the `${` bigram shows up in the
-    // script so replace these with something else.
-    // For completeness, if we trigger this, also replace closing curly braces with underscores.
-    if (identifier.contains("\${")) {
-        identifier = identifier.replace("$", "_").replace("{", "_").replace("}", "_")
-    }
-
-    // Escape double quotes
-    identifier = escapeJsonIdentifier(identifier)
-
-    return identifier.uppercase()
-}
@@ -0,0 +1,116 @@
+/*
+ * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.integrations.destination.snowflake.schema
+
+import io.airbyte.cdk.load.component.ColumnType
+import io.airbyte.cdk.load.message.Meta
+import io.airbyte.cdk.load.schema.model.ColumnSchema
+import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
+import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_EXTRACTED_AT
+import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_GENERATION_ID
+import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_META
+import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_RAW_ID
+import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDataType
+import jakarta.inject.Singleton
+
+/**
+ * Manages column names and ordering for Snowflake tables based on whether legacy raw tables mode is
+ * enabled.
+ *
+ * TODO: We should add meta column munging and raw table support to the CDK, so this extra layer of
+ * management shouldn't be necessary.
+ */
+@Singleton
+class SnowflakeColumnManager(
+    private val config: SnowflakeConfiguration,
+) {
+    /**
+     * Get the list of column names for a table in the order they should appear in the CSV file and
+     * COPY INTO statement.
+     *
+     * Warning: MUST match the order defined in SnowflakeRecordFormatter
+     *
+     * @param columnSchema The schema containing column information (ignored in raw mode)
+     * @return List of column names in the correct order
+     */
+    fun getTableColumnNames(columnSchema: ColumnSchema): List<String> {
+        return buildList {
+            addAll(getMetaColumnNames())
+            addAll(columnSchema.finalSchema.keys)
+        }
+    }
+
+    /**
+     * Get the list of Airbyte meta column names. In schema mode, these are uppercase. In raw mode,
+     * they are lowercase and included loaded_at
+     *
+     * @return Set of meta column names
+     */
+    fun getMetaColumnNames(): Set<String> =
+        if (config.legacyRawTablesOnly) {
+            Constants.rawModeMetaColNames
+        } else {
+            Constants.schemaModeMetaColNames
+        }
+
+    /**
+     * Get the Airbyte meta columns as a map of column name to ColumnType. This provides both the
+     * column names and their types for table creation.
+     *
+     * @param columnSchema The user column schema (used to check for CDC columns)
+     * @return Map of meta column names to their types
+     */
+    fun getMetaColumns(): LinkedHashMap<String, ColumnType> {
+        return if (config.legacyRawTablesOnly) {
+            Constants.rawModeMetaColumns
+        } else {
+            Constants.schemaModeMetaColumns
+        }
+    }
+
+    fun getGenerationIdColumnName(): String {
+        return if (config.legacyRawTablesOnly) {
+            Meta.COLUMN_NAME_AB_GENERATION_ID
+        } else {
+            SNOWFLAKE_AB_GENERATION_ID
+        }
+    }
+
+    object Constants {
+        val rawModeMetaColumns =
+            linkedMapOf(
+                Meta.COLUMN_NAME_AB_RAW_ID to ColumnType(SnowflakeDataType.VARCHAR.typeName, false),
+                Meta.COLUMN_NAME_AB_EXTRACTED_AT to
+                    ColumnType(
+                        SnowflakeDataType.TIMESTAMP_TZ.typeName,
+                        false,
+                    ),
+                Meta.COLUMN_NAME_AB_META to ColumnType(SnowflakeDataType.VARIANT.typeName, false),
+                Meta.COLUMN_NAME_AB_GENERATION_ID to
+                    ColumnType(
+                        SnowflakeDataType.NUMBER.typeName,
+                        true,
+                    ),
+                Meta.COLUMN_NAME_AB_LOADED_AT to
+                    ColumnType(
+                        SnowflakeDataType.TIMESTAMP_TZ.typeName,
+                        true,
+                    ),
+            )
+
+        val schemaModeMetaColumns =
+            linkedMapOf(
+                SNOWFLAKE_AB_RAW_ID to ColumnType(SnowflakeDataType.VARCHAR.typeName, false),
+                SNOWFLAKE_AB_EXTRACTED_AT to
+                    ColumnType(SnowflakeDataType.TIMESTAMP_TZ.typeName, false),
+                SNOWFLAKE_AB_META to ColumnType(SnowflakeDataType.VARIANT.typeName, false),
+                SNOWFLAKE_AB_GENERATION_ID to ColumnType(SnowflakeDataType.NUMBER.typeName, true),
+            )
+
+        val rawModeMetaColNames: Set<String> = rawModeMetaColumns.keys
+
+        val schemaModeMetaColNames: Set<String> = schemaModeMetaColumns.keys
+    }
+}
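
The two layouts the column manager switches between can be summarized concretely. A sketch with the meta column names hardcoded for illustration, following the constants above: raw mode keeps the lowercase CDK names plus _airbyte_loaded_at, while schema mode uses the uppercase Snowflake-compatible names without loaded_at.

// Stand-in for SnowflakeColumnManager.getMetaColumnNames(); names hardcoded for the sketch.
fun metaColumnNames(legacyRawTablesOnly: Boolean): Set<String> =
    if (legacyRawTablesOnly) {
        linkedSetOf(
            "_airbyte_raw_id",
            "_airbyte_extracted_at",
            "_airbyte_meta",
            "_airbyte_generation_id",
            "_airbyte_loaded_at",
        )
    } else {
        linkedSetOf(
            "_AIRBYTE_RAW_ID",
            "_AIRBYTE_EXTRACTED_AT",
            "_AIRBYTE_META",
            "_AIRBYTE_GENERATION_ID",
        )
    }

fun main() {
    println(metaColumnNames(legacyRawTablesOnly = true))
    println(metaColumnNames(legacyRawTablesOnly = false))
}
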
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.integrations.destination.snowflake.schema
+
+import io.airbyte.cdk.ConfigErrorException
+import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
+
+/**
+ * Transforms a string to be compatible with Snowflake table and column names.
+ *
+ * @return The transformed string suitable for Snowflake identifiers.
+ */
+fun String.toSnowflakeCompatibleName(): String {
+    var identifier = this
+
+    // Handle empty strings
+    if (identifier.isEmpty()) {
+        throw ConfigErrorException("Empty string is invalid identifier")
+    }
+
+    // Snowflake scripting language does something weird when the `${` bigram shows up in the
+    // script so replace these with something else.
+    // For completeness, if we trigger this, also replace closing curly braces with underscores.
+    if (identifier.contains("\${")) {
+        identifier = identifier.replace("$", "_").replace("{", "_").replace("}", "_")
+    }
+
+    // Escape double quotes
+    identifier = escapeJsonIdentifier(identifier)
+
+    return identifier.uppercase()
+}
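
To see what the relocated extension does end to end, here is a standalone copy of its logic that runs without the CDK (ConfigErrorException swapped for require, and the quote-escaping inlined):

// Standalone sketch of toSnowflakeCompatibleName for local experimentation.
fun String.toSnowflakeCompatibleNameSketch(): String {
    var identifier = this
    require(identifier.isNotEmpty()) { "Empty string is invalid identifier" }
    // `${` trips up Snowflake scripting, so neutralize dollar signs and braces.
    if (identifier.contains("\${")) {
        identifier = identifier.replace("$", "_").replace("{", "_").replace("}", "_")
    }
    // Double any embedded double quotes.
    identifier = identifier.replace("\"", "\"\"")
    return identifier.uppercase()
}

fun main() {
    println("users".toSnowflakeCompatibleNameSketch())     // USERS
    println("my\${col}".toSnowflakeCompatibleNameSketch()) // MY__COL_
    println("a\"b".toSnowflakeCompatibleNameSketch())      // A""B
}
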
@@ -0,0 +1,121 @@
+/*
+ * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.integrations.destination.snowflake.schema
+
+import io.airbyte.cdk.load.command.DestinationStream
+import io.airbyte.cdk.load.component.ColumnType
+import io.airbyte.cdk.load.data.ArrayType
+import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
+import io.airbyte.cdk.load.data.BooleanType
+import io.airbyte.cdk.load.data.DateType
+import io.airbyte.cdk.load.data.FieldType
+import io.airbyte.cdk.load.data.IntegerType
+import io.airbyte.cdk.load.data.NumberType
+import io.airbyte.cdk.load.data.ObjectType
+import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
+import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
+import io.airbyte.cdk.load.data.StringType
+import io.airbyte.cdk.load.data.TimeTypeWithTimezone
+import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
+import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
+import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
+import io.airbyte.cdk.load.data.UnionType
+import io.airbyte.cdk.load.data.UnknownType
+import io.airbyte.cdk.load.message.Meta
+import io.airbyte.cdk.load.schema.TableSchemaMapper
+import io.airbyte.cdk.load.schema.model.StreamTableSchema
+import io.airbyte.cdk.load.schema.model.TableName
+import io.airbyte.cdk.load.table.TempTableNameGenerator
+import io.airbyte.cdk.load.table.TypingDedupingUtil
+import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
+import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDataType
+import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
+import jakarta.inject.Singleton
+
+@Singleton
+class SnowflakeTableSchemaMapper(
+    private val config: SnowflakeConfiguration,
+    private val tempTableNameGenerator: TempTableNameGenerator,
+) : TableSchemaMapper {
+    override fun toFinalTableName(desc: DestinationStream.Descriptor): TableName {
+        val namespace = desc.namespace ?: config.schema
+        return if (!config.legacyRawTablesOnly) {
+            TableName(
+                namespace = namespace.toSnowflakeCompatibleName(),
+                name = desc.name.toSnowflakeCompatibleName(),
+            )
+        } else {
+            TableName(
+                namespace = config.internalTableSchema,
+                name =
+                    TypingDedupingUtil.concatenateRawTableName(
+                        namespace = escapeJsonIdentifier(namespace),
+                        name = escapeJsonIdentifier(desc.name),
+                    ),
+            )
+        }
+    }
+
+    override fun toTempTableName(tableName: TableName): TableName {
+        return tempTableNameGenerator.generate(tableName)
+    }
+
+    override fun toColumnName(name: String): String {
+        return if (!config.legacyRawTablesOnly) {
+            name.toSnowflakeCompatibleName()
+        } else {
+            // In legacy mode, column names are not transformed
+            name
+        }
+    }
+
+    override fun toColumnType(fieldType: FieldType): ColumnType {
+        val snowflakeType =
+            when (fieldType.type) {
+                // Simple types
+                BooleanType -> SnowflakeDataType.BOOLEAN.typeName
+                IntegerType -> SnowflakeDataType.NUMBER.typeName
+                NumberType -> SnowflakeDataType.FLOAT.typeName
+                StringType -> SnowflakeDataType.VARCHAR.typeName
+
+                // Temporal types
+                DateType -> SnowflakeDataType.DATE.typeName
+                TimeTypeWithTimezone -> SnowflakeDataType.VARCHAR.typeName
+                TimeTypeWithoutTimezone -> SnowflakeDataType.TIME.typeName
+                TimestampTypeWithTimezone -> SnowflakeDataType.TIMESTAMP_TZ.typeName
+                TimestampTypeWithoutTimezone -> SnowflakeDataType.TIMESTAMP_NTZ.typeName
+
+                // Semistructured types
+                is ArrayType,
+                ArrayTypeWithoutSchema -> SnowflakeDataType.ARRAY.typeName
+                is ObjectType,
+                ObjectTypeWithEmptySchema,
+                ObjectTypeWithoutSchema -> SnowflakeDataType.OBJECT.typeName
+                is UnionType -> SnowflakeDataType.VARIANT.typeName
+                is UnknownType -> SnowflakeDataType.VARIANT.typeName
+            }
+
+        return ColumnType(snowflakeType, fieldType.nullable)
+    }
+
+    override fun toFinalSchema(tableSchema: StreamTableSchema): StreamTableSchema {
+        if (!config.legacyRawTablesOnly) {
+            return tableSchema
+        }
+
+        return StreamTableSchema(
+            tableNames = tableSchema.tableNames,
+            columnSchema =
+                tableSchema.columnSchema.copy(
+                    finalSchema =
+                        mapOf(
+                            Meta.COLUMN_NAME_DATA to
+                                ColumnType(SnowflakeDataType.OBJECT.typeName, false)
+                        )
+                ),
+            importType = tableSchema.importType,
+        )
+    }
+}
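
The interesting branch in the mapper is toFinalSchema: in legacy raw mode the typed user schema collapses to a single _airbyte_data OBJECT column, while schema mode passes the typed columns through. A small stand-in sketch of that collapse (Column is a local type, not the CDK's ColumnType):

data class Column(val type: String, val nullable: Boolean)

// Mirrors the toFinalSchema branch above with plain maps.
fun finalSchemaFor(legacyRawTablesOnly: Boolean, userSchema: Map<String, Column>): Map<String, Column> =
    if (legacyRawTablesOnly) {
        mapOf("_airbyte_data" to Column("OBJECT", nullable = false))
    } else {
        userSchema
    }

fun main() {
    val user = mapOf("ID" to Column("NUMBER", false), "NAME" to Column("VARCHAR", true))
    println(finalSchemaFor(false, user)) // typed columns pass through
    println(finalSchemaFor(true, user))  // collapsed to one OBJECT column
}
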
@@ -1,201 +0,0 @@
-/*
- * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.snowflake.sql
-
-import com.google.common.annotations.VisibleForTesting
-import io.airbyte.cdk.load.data.AirbyteType
-import io.airbyte.cdk.load.data.ArrayType
-import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
-import io.airbyte.cdk.load.data.BooleanType
-import io.airbyte.cdk.load.data.DateType
-import io.airbyte.cdk.load.data.FieldType
-import io.airbyte.cdk.load.data.IntegerType
-import io.airbyte.cdk.load.data.NumberType
-import io.airbyte.cdk.load.data.ObjectType
-import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
-import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
-import io.airbyte.cdk.load.data.StringType
-import io.airbyte.cdk.load.data.TimeTypeWithTimezone
-import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
-import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
-import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
-import io.airbyte.cdk.load.data.UnionType
-import io.airbyte.cdk.load.data.UnknownType
-import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
-import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
-import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT
-import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
-import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
-import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA
-import io.airbyte.cdk.load.table.ColumnNameMapping
-import io.airbyte.integrations.destination.snowflake.db.SnowflakeColumnNameGenerator
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
-import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import jakarta.inject.Singleton
-import kotlin.collections.component1
-import kotlin.collections.component2
-import kotlin.collections.joinToString
-import kotlin.collections.map
-import kotlin.collections.plus
-
-internal const val NOT_NULL = "NOT NULL"
-
-internal val DEFAULT_COLUMNS =
-    listOf(
-        ColumnAndType(
-            columnName = COLUMN_NAME_AB_RAW_ID,
-            columnType = "${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL"
-        ),
-        ColumnAndType(
-            columnName = COLUMN_NAME_AB_EXTRACTED_AT,
-            columnType = "${SnowflakeDataType.TIMESTAMP_TZ.typeName} $NOT_NULL"
-        ),
-        ColumnAndType(
-            columnName = COLUMN_NAME_AB_META,
-            columnType = "${SnowflakeDataType.VARIANT.typeName} $NOT_NULL"
-        ),
-        ColumnAndType(
-            columnName = COLUMN_NAME_AB_GENERATION_ID,
-            columnType = SnowflakeDataType.NUMBER.typeName
-        ),
-    )
-
-internal val RAW_DATA_COLUMN =
-    ColumnAndType(
-        columnName = COLUMN_NAME_DATA,
-        columnType = "${SnowflakeDataType.VARIANT.typeName} $NOT_NULL"
-    )
-
-internal val RAW_COLUMNS =
-    listOf(
-        ColumnAndType(
-            columnName = COLUMN_NAME_AB_LOADED_AT,
-            columnType = SnowflakeDataType.TIMESTAMP_TZ.typeName
-        ),
-        RAW_DATA_COLUMN
-    )
-
-@Singleton
-class SnowflakeColumnUtils(
-    private val snowflakeConfiguration: SnowflakeConfiguration,
-    private val snowflakeColumnNameGenerator: SnowflakeColumnNameGenerator,
-) {
-
-    @VisibleForTesting
-    internal fun defaultColumns(): List<ColumnAndType> =
-        if (snowflakeConfiguration.legacyRawTablesOnly) {
-            DEFAULT_COLUMNS + RAW_COLUMNS
-        } else {
-            DEFAULT_COLUMNS
-        }
-
-    internal fun formattedDefaultColumns(): List<ColumnAndType> =
-        defaultColumns().map {
-            ColumnAndType(
-                columnName = formatColumnName(it.columnName, false),
-                columnType = it.columnType,
-            )
-        }
-
-    fun getGenerationIdColumnName(): String {
-        return if (snowflakeConfiguration.legacyRawTablesOnly) {
-            COLUMN_NAME_AB_GENERATION_ID
-        } else {
-            COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()
-        }
-    }
-
-    fun getColumnNames(columnNameMapping: ColumnNameMapping): String =
-        if (snowflakeConfiguration.legacyRawTablesOnly) {
-            getFormattedDefaultColumnNames(true).joinToString(",")
-        } else {
-            (getFormattedDefaultColumnNames(true) +
-                columnNameMapping.map { (_, actualName) -> actualName.quote() })
-                .joinToString(",")
-        }
-
-    fun getFormattedDefaultColumnNames(quote: Boolean = false): List<String> =
-        defaultColumns().map { formatColumnName(it.columnName, quote) }
-
-    fun getFormattedColumnNames(
-        columns: Map<String, FieldType>,
-        columnNameMapping: ColumnNameMapping,
-        quote: Boolean = true,
-    ): List<String> =
-        if (snowflakeConfiguration.legacyRawTablesOnly) {
-            getFormattedDefaultColumnNames(quote)
-        } else {
-            getFormattedDefaultColumnNames(quote) +
-                columns.map { (fieldName, _) ->
-                    val columnName = columnNameMapping[fieldName] ?: fieldName
-                    if (quote) columnName.quote() else columnName
-                }
-        }
-
-    fun columnsAndTypes(
-        columns: Map<String, FieldType>,
-        columnNameMapping: ColumnNameMapping
-    ): List<ColumnAndType> =
-        if (snowflakeConfiguration.legacyRawTablesOnly) {
-            formattedDefaultColumns()
-        } else {
-            formattedDefaultColumns() +
-                columns.map { (fieldName, type) ->
-                    val columnName = columnNameMapping[fieldName] ?: fieldName
-                    val typeName = toDialectType(type.type)
-                    ColumnAndType(
-                        columnName = columnName,
-                        columnType = if (type.nullable) typeName else "$typeName $NOT_NULL",
-                    )
-                }
-        }
-
-    fun formatColumnName(
-        columnName: String,
-        quote: Boolean = true,
-    ): String {
-        val formattedColumnName =
-            if (columnName == COLUMN_NAME_DATA) columnName
-            else snowflakeColumnNameGenerator.getColumnName(columnName).displayName
-        return if (quote) formattedColumnName.quote() else formattedColumnName
-    }
-
-    fun toDialectType(type: AirbyteType): String =
-        when (type) {
-            // Simple types
-            BooleanType -> SnowflakeDataType.BOOLEAN.typeName
-            IntegerType -> SnowflakeDataType.NUMBER.typeName
-            NumberType -> SnowflakeDataType.FLOAT.typeName
-            StringType -> SnowflakeDataType.VARCHAR.typeName
-
-            // Temporal types
-            DateType -> SnowflakeDataType.DATE.typeName
-            TimeTypeWithTimezone -> SnowflakeDataType.VARCHAR.typeName
-            TimeTypeWithoutTimezone -> SnowflakeDataType.TIME.typeName
-            TimestampTypeWithTimezone -> SnowflakeDataType.TIMESTAMP_TZ.typeName
-            TimestampTypeWithoutTimezone -> SnowflakeDataType.TIMESTAMP_NTZ.typeName
-
-            // Semistructured types
-            is ArrayType,
-            ArrayTypeWithoutSchema -> SnowflakeDataType.ARRAY.typeName
-            is ObjectType,
-            ObjectTypeWithEmptySchema,
-            ObjectTypeWithoutSchema -> SnowflakeDataType.OBJECT.typeName
-            is UnionType -> SnowflakeDataType.VARIANT.typeName
-            is UnknownType -> SnowflakeDataType.VARIANT.typeName
-        }
-}
-
-data class ColumnAndType(val columnName: String, val columnType: String) {
-    override fun toString(): String {
-        return "${columnName.quote()} $columnType"
-    }
-}
-
-/**
- * Surrounds the string instance with double quotation marks (e.g. "some string" -> "\"some
- * string\"").
- */
-fun String.quote() = "$QUOTE$this$QUOTE"
@@ -10,7 +10,7 @@ package io.airbyte.integrations.destination.snowflake.sql
 */
enum class SnowflakeDataType(val typeName: String) {
     // Numeric types
-    NUMBER("NUMBER(38,0)"),
+    NUMBER("NUMBER"),
     FLOAT("FLOAT"),

     // String & binary types
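
In Snowflake, NUMBER with no arguments defaults to NUMBER(38,0), so the two spellings name the same type; the switch mainly simplifies comparing type strings against what the database reports. The client already strips precision before comparing, as a quick check of the takeWhile idiom used in SnowflakeAirbyteClient above shows:

// Mirrors the precision-stripping applied to SHOW COLUMNS type strings.
fun main() {
    val reported = "NUMBER(38,0)"
    println(reported.takeWhile { it != '(' }) // prints NUMBER
}
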
@@ -4,16 +4,19 @@

 package io.airbyte.integrations.destination.snowflake.sql

-import io.airbyte.cdk.load.command.Dedupe
-import io.airbyte.cdk.load.command.DestinationStream
+import com.google.common.annotations.VisibleForTesting
 import io.airbyte.cdk.load.component.ColumnType
 import io.airbyte.cdk.load.component.ColumnTypeChange
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
+import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
+import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
+import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
+import io.airbyte.cdk.load.schema.model.StreamTableSchema
+import io.airbyte.cdk.load.schema.model.TableName
 import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN
-import io.airbyte.cdk.load.table.ColumnNameMapping
-import io.airbyte.cdk.load.table.TableName
 import io.airbyte.cdk.load.util.UUIDGenerator
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
 import io.airbyte.integrations.destination.snowflake.spec.CdcDeletionMode
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
 import io.airbyte.integrations.destination.snowflake.write.load.CSV_FIELD_SEPARATOR
@@ -22,6 +25,14 @@ import io.github.oshai.kotlinlogging.KotlinLogging
 import jakarta.inject.Singleton

 internal const val COUNT_TOTAL_ALIAS = "TOTAL"
+internal const val NOT_NULL = "NOT NULL"
+
+// Snowflake-compatible (uppercase) versions of the Airbyte meta column names
+internal val SNOWFLAKE_AB_RAW_ID = COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName()
+internal val SNOWFLAKE_AB_EXTRACTED_AT = COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()
+internal val SNOWFLAKE_AB_META = COLUMN_NAME_AB_META.toSnowflakeCompatibleName()
+internal val SNOWFLAKE_AB_GENERATION_ID = COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()
+internal val SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN = CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()

 private val log = KotlinLogging.logger {}

@@ -36,80 +47,91 @@ fun String.andLog(): String {
 
 @Singleton
 class SnowflakeDirectLoadSqlGenerator(
-    private val columnUtils: SnowflakeColumnUtils,
     private val uuidGenerator: UUIDGenerator,
-    private val snowflakeConfiguration: SnowflakeConfiguration,
-    private val snowflakeSqlNameUtils: SnowflakeSqlNameUtils,
+    private val config: SnowflakeConfiguration,
+    private val columnManager: SnowflakeColumnManager,
 ) {
     fun countTable(tableName: TableName): String {
-        return "SELECT COUNT(*) AS $COUNT_TOTAL_ALIAS FROM ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog()
+        return "SELECT COUNT(*) AS $COUNT_TOTAL_ALIAS FROM ${fullyQualifiedName(tableName)}".andLog()
     }
 
     fun createNamespace(namespace: String): String {
-        return "CREATE SCHEMA IF NOT EXISTS ${snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)}".andLog()
+        return "CREATE SCHEMA IF NOT EXISTS ${fullyQualifiedNamespace(namespace)}".andLog()
     }
 
     fun createTable(
-        stream: DestinationStream,
         tableName: TableName,
-        columnNameMapping: ColumnNameMapping,
+        tableSchema: StreamTableSchema,
         replace: Boolean
     ): String {
+        val finalSchema = tableSchema.columnSchema.finalSchema
+        val metaColumns = columnManager.getMetaColumns()
+
+        // Build column declarations from the meta columns and user schema
         val columnDeclarations =
-            columnUtils
-                .columnsAndTypes(stream.schema.asColumns(), columnNameMapping)
-                .joinToString(",\n")
+            buildList {
+                    // Add Airbyte meta columns from the column manager
+                    metaColumns.forEach { (columnName, columnType) ->
+                        val nullability = if (columnType.nullable) "" else " NOT NULL"
+                        add("${columnName.quote()} ${columnType.type}$nullability")
+                    }
+
+                    // Add user columns from the munged schema
+                    finalSchema.forEach { (columnName, columnType) ->
+                        val nullability = if (columnType.nullable) "" else " NOT NULL"
+                        add("${columnName.quote()} ${columnType.type}$nullability")
+                    }
+                }
+                .joinToString(",\n    ")
 
         // Snowflake supports CREATE OR REPLACE TABLE, which is simpler than drop+recreate
         val createOrReplace = if (replace) "CREATE OR REPLACE" else "CREATE"
 
         val createTableStatement =
             """
-            $createOrReplace TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)} (
-                $columnDeclarations
-            )
-            """.trimIndent()
+            |$createOrReplace TABLE ${fullyQualifiedName(tableName)} (
+            |    $columnDeclarations
+            |)
+            """.trimMargin() // Something was tripping up trimIndent so we opt for trimMargin
 
         return createTableStatement.andLog()
     }
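The inline comment on `trimMargin()` is terse; the likely failure mode is that `trimIndent()` runs on the final string after interpolation, so a multi-line value spliced in through `$columnDeclarations` can drag the common indent down to zero. A minimal, self-contained repro (all names invented):

```kotlin
fun main() {
    // A multi-line interpolated value whose second line starts at column 0.
    val cols = "\"ID\" NUMBER,\n\"NAME\" TEXT"

    // trimIndent() computes the common indent over the *interpolated* result,
    // so the unindented "NAME" line forces the common indent to 0 and the
    // surrounding lines keep their leading spaces.
    val broken =
        """
        CREATE TABLE T (
            $cols
        )
        """.trimIndent()

    // trimMargin() strips each line only up to its '|' prefix; interpolated
    // lines without a margin are left alone, so the statement stays aligned.
    val ok =
        """
        |CREATE TABLE T (
        |    $cols
        |)
        """.trimMargin()

    println(broken)
    println(ok)
}
```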
 
     fun showColumns(tableName: TableName): String =
-        "SHOW COLUMNS IN TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog()
+        "SHOW COLUMNS IN TABLE ${fullyQualifiedName(tableName)}".andLog()
 
     fun copyTable(
-        columnNameMapping: ColumnNameMapping,
+        columnNames: Set<String>,
         sourceTableName: TableName,
         targetTableName: TableName
     ): String {
-        val columnNames = columnUtils.getColumnNames(columnNameMapping)
+        val columnList = columnNames.joinToString(", ") { it.quote() }
 
         return """
-            INSERT INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)}
+            INSERT INTO ${fullyQualifiedName(targetTableName)}
             (
-                $columnNames
+                $columnList
             )
             SELECT
-                $columnNames
-            FROM ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)}
+                $columnList
+            FROM ${fullyQualifiedName(sourceTableName)}
             """
             .trimIndent()
             .andLog()
     }
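For reference, the rewritten `copyTable` renders an explicit column-list INSERT ... SELECT between two tables. A sketch of the output for a hypothetical two-column stream (all identifiers invented, and the database name assumed to resolve to "AIRBYTE_DB"):

```kotlin
// Roughly what copyTable(setOf("ID", "NAME"), source, target) would emit:
val copyTableSql =
    """
    INSERT INTO "AIRBYTE_DB"."PUBLIC"."USERS"
    (
        "ID", "NAME"
    )
    SELECT
        "ID", "NAME"
    FROM "AIRBYTE_DB"."PUBLIC"."USERS_TMP"
    """.trimIndent()
```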
 
     fun upsertTable(
-        stream: DestinationStream,
-        columnNameMapping: ColumnNameMapping,
+        tableSchema: StreamTableSchema,
         sourceTableName: TableName,
         targetTableName: TableName
     ): String {
-        val importType = stream.importType as Dedupe
+        val finalSchema = tableSchema.columnSchema.finalSchema
 
         // Build primary key matching condition
+        val pks = tableSchema.getPrimaryKey().flatten()
         val pkEquivalent =
-            if (importType.primaryKey.isNotEmpty()) {
-                importType.primaryKey.joinToString(" AND ") { fieldPath ->
-                    val fieldName = fieldPath.first()
-                    val columnName = columnNameMapping[fieldName] ?: fieldName
+            if (pks.isNotEmpty()) {
+                pks.joinToString(" AND ") { columnName ->
                     val targetTableColumnName = "target_table.${columnName.quote()}"
                     val newRecordColumnName = "new_record.${columnName.quote()}"
                     """($targetTableColumnName = $newRecordColumnName OR ($targetTableColumnName IS NULL AND $newRecordColumnName IS NULL))"""
@@ -120,80 +142,62 @@ class SnowflakeDirectLoadSqlGenerator(
                 }
 
         // Build column lists for INSERT and UPDATE
-        val columnList: String =
-            columnUtils
-                .getFormattedColumnNames(
-                    columns = stream.schema.asColumns(),
-                    columnNameMapping = columnNameMapping,
-                    quote = false,
-                )
-                .joinToString(
-                    ",\n",
-                ) {
-                    it.quote()
-                }
+        val allColumns = buildList {
+            add(SNOWFLAKE_AB_RAW_ID)
+            add(SNOWFLAKE_AB_EXTRACTED_AT)
+            add(SNOWFLAKE_AB_META)
+            add(SNOWFLAKE_AB_GENERATION_ID)
+            addAll(finalSchema.keys)
+        }
 
+        val columnList: String = allColumns.joinToString(",\n    ") { it.quote() }
         val newRecordColumnList: String =
-            columnUtils
-                .getFormattedColumnNames(
-                    columns = stream.schema.asColumns(),
-                    columnNameMapping = columnNameMapping,
-                    quote = false,
-                )
-                .joinToString(",\n") { "new_record.${it.quote()}" }
+            allColumns.joinToString(",\n    ") { "new_record.${it.quote()}" }
 
         // Get deduped records from source
-        val selectSourceRecords = selectDedupedRecords(stream, sourceTableName, columnNameMapping)
+        val selectSourceRecords = selectDedupedRecords(tableSchema, sourceTableName)
 
         // Build cursor comparison for determining which record is newer
         val cursorComparison: String
-        if (importType.cursor.isNotEmpty()) {
-            val cursorFieldName = importType.cursor.first()
-            val cursor = (columnNameMapping[cursorFieldName] ?: cursorFieldName)
+        val cursor = tableSchema.getCursor().firstOrNull()
+        if (cursor != null) {
             val targetTableCursor = "target_table.${cursor.quote()}"
             val newRecordCursor = "new_record.${cursor.quote()}"
             cursorComparison =
                 """
                 (
                   $targetTableCursor < $newRecordCursor
-                  OR ($targetTableCursor = $newRecordCursor AND target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}")
-                  OR ($targetTableCursor IS NULL AND $newRecordCursor IS NULL AND target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}")
+                  OR ($targetTableCursor = $newRecordCursor AND target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT")
+                  OR ($targetTableCursor IS NULL AND $newRecordCursor IS NULL AND target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT")
                   OR ($targetTableCursor IS NULL AND $newRecordCursor IS $NOT_NULL)
                 )
                 """.trimIndent()
         } else {
             // No cursor - use extraction timestamp only
             cursorComparison =
-                """target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}""""
+                """target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT""""
         }
 
         // Build column assignments for UPDATE
         val columnAssignments: String =
-            columnUtils
-                .getFormattedColumnNames(
-                    columns = stream.schema.asColumns(),
-                    columnNameMapping = columnNameMapping,
-                    quote = false,
-                )
-                .joinToString(",\n") { column ->
-                    "${column.quote()} = new_record.${column.quote()}"
-                }
+            allColumns.joinToString(",\n    ") { column ->
+                "${column.quote()} = new_record.${column.quote()}"
+            }
 
         // Handle CDC deletions based on mode
         val cdcDeleteClause: String
         val cdcSkipInsertClause: String
         if (
-            stream.schema.asColumns().containsKey(CDC_DELETED_AT_COLUMN) &&
-                snowflakeConfiguration.cdcDeletionMode == CdcDeletionMode.HARD_DELETE
+            finalSchema.containsKey(SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN) &&
+                config.cdcDeletionMode == CdcDeletionMode.HARD_DELETE
         ) {
             // Execute CDC deletions if there's already a record
             cdcDeleteClause =
-                "WHEN MATCHED AND new_record.\"${CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()}\" IS NOT NULL AND $cursorComparison THEN DELETE"
+                "WHEN MATCHED AND new_record.\"${SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN}\" IS NOT NULL AND $cursorComparison THEN DELETE"
             // And skip insertion entirely if there's no matching record.
             // (This is possible if a single T+D batch contains both an insertion and deletion for
             // the same PK)
-            cdcSkipInsertClause =
-                "AND new_record.\"${CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()}\" IS NULL"
+            cdcSkipInsertClause = "AND new_record.\"${SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN}\" IS NULL"
         } else {
             cdcDeleteClause = ""
             cdcSkipInsertClause = ""
@@ -203,35 +207,35 @@ class SnowflakeDirectLoadSqlGenerator(
         val mergeStatement =
             if (cdcDeleteClause.isNotEmpty()) {
                 """
-                MERGE INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} AS target_table
-                USING (
-                    $selectSourceRecords
-                ) AS new_record
-                ON $pkEquivalent
-                $cdcDeleteClause
-                WHEN MATCHED AND $cursorComparison THEN UPDATE SET
-                    $columnAssignments
-                WHEN NOT MATCHED $cdcSkipInsertClause THEN INSERT (
-                    $columnList
-                ) VALUES (
-                    $newRecordColumnList
-                )
-                """.trimIndent()
+                |MERGE INTO ${fullyQualifiedName(targetTableName)} AS target_table
+                |USING (
+                |$selectSourceRecords
+                |) AS new_record
+                |ON $pkEquivalent
+                |$cdcDeleteClause
+                |WHEN MATCHED AND $cursorComparison THEN UPDATE SET
+                |    $columnAssignments
+                |WHEN NOT MATCHED $cdcSkipInsertClause THEN INSERT (
+                |    $columnList
+                |) VALUES (
+                |    $newRecordColumnList
+                |)
+                """.trimMargin()
             } else {
                 """
-                MERGE INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} AS target_table
-                USING (
-                    $selectSourceRecords
-                ) AS new_record
-                ON $pkEquivalent
-                WHEN MATCHED AND $cursorComparison THEN UPDATE SET
-                    $columnAssignments
-                WHEN NOT MATCHED THEN INSERT (
-                    $columnList
-                ) VALUES (
-                    $newRecordColumnList
-                )
-                """.trimIndent()
+                |MERGE INTO ${fullyQualifiedName(targetTableName)} AS target_table
+                |USING (
+                |$selectSourceRecords
+                |) AS new_record
+                |ON $pkEquivalent
+                |WHEN MATCHED AND $cursorComparison THEN UPDATE SET
+                |    $columnAssignments
+                |WHEN NOT MATCHED THEN INSERT (
+                |    $columnList
+                |) VALUES (
+                |    $newRecordColumnList
+                |)
+                """.trimMargin()
             }
 
         return mergeStatement.andLog()
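To make the assembled statement concrete: for a hypothetical dedup stream keyed on `ID`, with no cursor and hard-deletes disabled, the pieces above combine into a MERGE of roughly this shape. All identifiers are invented, and the USING subquery is the deduped SELECT produced by `selectDedupedRecords` below:

```kotlin
val exampleMerge =
    """
    |MERGE INTO "AIRBYTE_DB"."PUBLIC"."USERS" AS target_table
    |USING (
    |    SELECT "_AIRBYTE_RAW_ID", "ID" FROM "AIRBYTE_DB"."PUBLIC"."USERS_TMP" -- deduped
    |) AS new_record
    |ON (target_table."ID" = new_record."ID" OR (target_table."ID" IS NULL AND new_record."ID" IS NULL))
    |WHEN MATCHED AND target_table."_AIRBYTE_EXTRACTED_AT" < new_record."_AIRBYTE_EXTRACTED_AT" THEN UPDATE SET
    |    "_AIRBYTE_RAW_ID" = new_record."_AIRBYTE_RAW_ID",
    |    "ID" = new_record."ID"
    |WHEN NOT MATCHED THEN INSERT (
    |    "_AIRBYTE_RAW_ID",
    |    "ID"
    |) VALUES (
    |    new_record."_AIRBYTE_RAW_ID",
    |    new_record."ID"
    |)
    """.trimMargin()
```

Note the NULL-safe equality in the ON clause: two NULL keys compare as a match, which plain `=` would not do.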
@@ -242,75 +246,66 @@ class SnowflakeDirectLoadSqlGenerator(
      * table. Uses ROW_NUMBER() window function to select the most recent record per primary key.
      */
     private fun selectDedupedRecords(
-        stream: DestinationStream,
-        sourceTableName: TableName,
-        columnNameMapping: ColumnNameMapping
+        tableSchema: StreamTableSchema,
+        sourceTableName: TableName
     ): String {
-        val columnList: String =
-            columnUtils
-                .getFormattedColumnNames(
-                    columns = stream.schema.asColumns(),
-                    columnNameMapping = columnNameMapping,
-                    quote = false,
-                )
-                .joinToString(
-                    ",\n",
-                ) {
-                    it.quote()
-                }
-        val importType = stream.importType as Dedupe
+        val allColumns = buildList {
+            add(SNOWFLAKE_AB_RAW_ID)
+            add(SNOWFLAKE_AB_EXTRACTED_AT)
+            add(SNOWFLAKE_AB_META)
+            add(SNOWFLAKE_AB_GENERATION_ID)
+            addAll(tableSchema.columnSchema.finalSchema.keys)
+        }
+        val columnList: String = allColumns.joinToString(",\n    ") { it.quote() }
 
         // Build the primary key list for partitioning
+        val pks = tableSchema.getPrimaryKey().flatten()
         val pkList =
-            if (importType.primaryKey.isNotEmpty()) {
-                importType.primaryKey.joinToString(",") { fieldPath ->
-                    (columnNameMapping[fieldPath.first()] ?: fieldPath.first()).quote()
-                }
+            if (pks.isNotEmpty()) {
+                pks.joinToString(",") { it.quote() }
             } else {
                 // Should not happen as we check this earlier, but handle it defensively
                 throw IllegalArgumentException("Cannot deduplicate without primary key")
             }
 
         // Build cursor order clause for sorting within each partition
+        val cursor = tableSchema.getCursor().firstOrNull()
         val cursorOrderClause =
-            if (importType.cursor.isNotEmpty()) {
-                val columnName =
-                    (columnNameMapping[importType.cursor.first()] ?: importType.cursor.first())
-                        .quote()
-                "$columnName DESC NULLS LAST,"
+            if (cursor != null) {
+                "${cursor.quote()} DESC NULLS LAST,"
             } else {
                 ""
             }
 
         return """
-            WITH records AS (
-                SELECT
-                    $columnList
-                FROM ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)}
-            ), numbered_rows AS (
-                SELECT *, ROW_NUMBER() OVER (
-                    PARTITION BY $pkList ORDER BY $cursorOrderClause "${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" DESC
-                ) AS row_number
-                FROM records
-            )
-            SELECT $columnList
-            FROM numbered_rows
-            WHERE row_number = 1
+            |    WITH records AS (
+            |        SELECT
+            |            $columnList
+            |        FROM ${fullyQualifiedName(sourceTableName)}
+            |    ), numbered_rows AS (
+            |        SELECT *, ROW_NUMBER() OVER (
+            |            PARTITION BY $pkList ORDER BY $cursorOrderClause "$SNOWFLAKE_AB_EXTRACTED_AT" DESC
+            |        ) AS row_number
+            |        FROM records
+            |    )
+            |    SELECT $columnList
+            |    FROM numbered_rows
+            |    WHERE row_number = 1
             """
-            .trimIndent()
+            .trimMargin()
            .andLog()
     }
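The ROW_NUMBER() pattern reads most clearly with a concrete rendering. For a hypothetical primary key `ID` and cursor `UPDATED_AT` (identifiers invented), the subquery keeps the newest row per key, falling back to `_AIRBYTE_EXTRACTED_AT` as the tie-breaker:

```kotlin
val exampleDedup =
    """
    |    WITH records AS (
    |        SELECT
    |            "_AIRBYTE_RAW_ID", "ID", "UPDATED_AT"
    |        FROM "AIRBYTE_DB"."PUBLIC"."USERS_TMP"
    |    ), numbered_rows AS (
    |        SELECT *, ROW_NUMBER() OVER (
    |            PARTITION BY "ID" ORDER BY "UPDATED_AT" DESC NULLS LAST, "_AIRBYTE_EXTRACTED_AT" DESC
    |        ) AS row_number
    |        FROM records
    |    )
    |    SELECT "_AIRBYTE_RAW_ID", "ID", "UPDATED_AT"
    |    FROM numbered_rows
    |    WHERE row_number = 1
    """.trimMargin()
```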
 
     fun dropTable(tableName: TableName): String {
-        return "DROP TABLE IF EXISTS ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog()
+        return "DROP TABLE IF EXISTS ${fullyQualifiedName(tableName)}".andLog()
     }
 
     fun getGenerationId(
         tableName: TableName,
     ): String {
         return """
-            SELECT "${columnUtils.getGenerationIdColumnName()}"
-            FROM ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}
+            SELECT "${columnManager.getGenerationIdColumnName()}"
+            FROM ${fullyQualifiedName(tableName)}
             LIMIT 1
             """
             .trimIndent()
@@ -318,12 +313,12 @@ class SnowflakeDirectLoadSqlGenerator(
     }
 
     fun createSnowflakeStage(tableName: TableName): String {
-        val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName)
+        val stageName = fullyQualifiedStageName(tableName)
         return "CREATE STAGE IF NOT EXISTS $stageName".andLog()
     }
 
     fun putInStage(tableName: TableName, tempFilePath: String): String {
-        val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true)
+        val stageName = fullyQualifiedStageName(tableName, true)
         return """
             PUT 'file://$tempFilePath' '@$stageName'
             AUTO_COMPRESS = FALSE
@@ -334,35 +329,45 @@ class SnowflakeDirectLoadSqlGenerator(
             .andLog()
     }
 
-    fun copyFromStage(tableName: TableName, filename: String): String {
-        val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true)
+    fun copyFromStage(
+        tableName: TableName,
+        filename: String,
+        columnNames: List<String>? = null
+    ): String {
+        val stageName = fullyQualifiedStageName(tableName, true)
+        val columnList =
+            columnNames?.let { names -> "(${names.joinToString(", ") { it.quote() }})" } ?: ""
 
         return """
-            COPY INTO ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}
-            FROM '@$stageName'
-            FILE_FORMAT = (
-                TYPE = 'CSV'
-                COMPRESSION = GZIP
-                FIELD_DELIMITER = '$CSV_FIELD_SEPARATOR'
-                RECORD_DELIMITER = '$CSV_LINE_DELIMITER'
-                FIELD_OPTIONALLY_ENCLOSED_BY = '"'
-                TRIM_SPACE = TRUE
-                ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE
-                REPLACE_INVALID_CHARACTERS = TRUE
-                ESCAPE = NONE
-                ESCAPE_UNENCLOSED_FIELD = NONE
-            )
-            ON_ERROR = 'ABORT_STATEMENT'
-            PURGE = TRUE
-            files = ('$filename')
+            |COPY INTO ${fullyQualifiedName(tableName)}$columnList
+            |FROM '@$stageName'
+            |FILE_FORMAT = (
+            |    TYPE = 'CSV'
+            |    COMPRESSION = GZIP
+            |    FIELD_DELIMITER = '$CSV_FIELD_SEPARATOR'
+            |    RECORD_DELIMITER = '$CSV_LINE_DELIMITER'
+            |    FIELD_OPTIONALLY_ENCLOSED_BY = '"'
+            |    TRIM_SPACE = TRUE
+            |    ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE
+            |    REPLACE_INVALID_CHARACTERS = TRUE
+            |    ESCAPE = NONE
+            |    ESCAPE_UNENCLOSED_FIELD = NONE
+            |)
+            |ON_ERROR = 'ABORT_STATEMENT'
+            |PURGE = TRUE
+            |files = ('$filename')
             """
-            .trimIndent()
+            .trimMargin()
             .andLog()
     }
 
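The new optional `columnNames` parameter is the substantive change here: naming the target columns pins each CSV field to a specific column, so a table whose column order has drifted through ALTER TABLE still loads correctly. A sketch of the rendered statement, with invented identifiers and an abbreviated FILE_FORMAT:

```kotlin
val exampleCopy =
    """
    |COPY INTO "AIRBYTE_DB"."PUBLIC"."USERS"("_AIRBYTE_RAW_ID", "ID", "NAME")
    |FROM '@"AIRBYTE_DB"."PUBLIC"."airbyte_stage_USERS"'
    |FILE_FORMAT = (
    |    TYPE = 'CSV'
    |    COMPRESSION = GZIP
    |)
    |ON_ERROR = 'ABORT_STATEMENT'
    |PURGE = TRUE
    |files = ('batch-0001.csv.gz')
    """.trimMargin()
```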
     fun swapTableWith(sourceTableName: TableName, targetTableName: TableName): String {
         return """
-            ALTER TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} SWAP WITH ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)}
+            ALTER TABLE ${fullyQualifiedName(sourceTableName)} SWAP WITH ${
+                fullyQualifiedName(
+                    targetTableName,
+                )
+            }
             """
             .trimIndent()
             .andLog()
@@ -372,7 +377,11 @@ class SnowflakeDirectLoadSqlGenerator(
         // Snowflake RENAME TO only accepts the table name, not a fully qualified name
         // The renamed table stays in the same schema
         return """
-            ALTER TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} RENAME TO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)}
+            ALTER TABLE ${fullyQualifiedName(sourceTableName)} RENAME TO ${
+                fullyQualifiedName(
+                    targetTableName,
+                )
+            }
             """
             .trimIndent()
             .andLog()
@@ -382,7 +391,7 @@ class SnowflakeDirectLoadSqlGenerator(
         schemaName: String,
         tableName: String,
     ): String =
-        """DESCRIBE TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(TableName(schemaName, tableName))}""".andLog()
+        """DESCRIBE TABLE ${fullyQualifiedName(TableName(schemaName, tableName))}""".andLog()
 
     fun alterTable(
         tableName: TableName,
@@ -391,14 +400,14 @@ class SnowflakeDirectLoadSqlGenerator(
         modifiedColumns: Map<String, ColumnTypeChange>,
     ): Set<String> {
         val clauses = mutableSetOf<String>()
-        val prettyTableName = snowflakeSqlNameUtils.fullyQualifiedName(tableName)
+        val prettyTableName = fullyQualifiedName(tableName)
         addedColumns.forEach { (name, columnType) ->
             clauses.add(
                 // Note that we intentionally don't set NOT NULL.
                 // We're adding a new column, and we don't know what constitutes a reasonable
                 // default value for preexisting records.
                 // So we add the column as nullable.
-                "ALTER TABLE $prettyTableName ADD COLUMN ${name.quote()} ${columnType.type};".andLog()
+                "ALTER TABLE $prettyTableName ADD COLUMN ${name.quote()} ${columnType.type};".andLog(),
             )
         }
         deletedColumns.forEach {
@@ -412,35 +421,34 @@ class SnowflakeDirectLoadSqlGenerator(
             val tempColumn = "${name}_${uuidGenerator.v4()}"
             clauses.add(
                 // As above: we add the column as nullable.
-                "ALTER TABLE $prettyTableName ADD COLUMN ${tempColumn.quote()} ${typeChange.newType.type};".andLog()
+                "ALTER TABLE $prettyTableName ADD COLUMN ${tempColumn.quote()} ${typeChange.newType.type};".andLog(),
             )
             clauses.add(
-                "UPDATE $prettyTableName SET ${tempColumn.quote()} = CAST(${name.quote()} AS ${typeChange.newType.type});".andLog()
+                "UPDATE $prettyTableName SET ${tempColumn.quote()} = CAST(${name.quote()} AS ${typeChange.newType.type});".andLog(),
             )
             val backupColumn = "${tempColumn}_backup"
             clauses.add(
                 """
                 ALTER TABLE $prettyTableName
                 RENAME COLUMN "$name" TO "$backupColumn";
-                """.trimIndent()
+                """.trimIndent(),
             )
             clauses.add(
                 """
                 ALTER TABLE $prettyTableName
                 RENAME COLUMN "$tempColumn" TO "$name";
-                """.trimIndent()
+                """.trimIndent(),
             )
             clauses.add(
-                "ALTER TABLE $prettyTableName DROP COLUMN ${backupColumn.quote()};".andLog()
+                "ALTER TABLE $prettyTableName DROP COLUMN ${backupColumn.quote()};".andLog(),
             )
         } else if (!typeChange.originalType.nullable && typeChange.newType.nullable) {
             // If the type is unchanged, we can change a column from NOT NULL to nullable.
             // But we'll never do the reverse, because there's a decent chance that historical
-            // records
-            // had null values.
+            // records had null values.
             // Users can always manually ALTER COLUMN ... SET NOT NULL if they want.
             clauses.add(
-                """ALTER TABLE $prettyTableName ALTER COLUMN "$name" DROP NOT NULL;""".andLog()
+                """ALTER TABLE $prettyTableName ALTER COLUMN "$name" DROP NOT NULL;""".andLog(),
             )
         } else {
             log.info {
@@ -450,4 +458,45 @@ class SnowflakeDirectLoadSqlGenerator(
             }
         }
         return clauses
     }
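For a type change, the generator emits a five-step add/backfill/swap/drop sequence rather than altering the column in place, presumably because Snowflake's ALTER COLUMN supports only narrow in-place type changes. Worked out for a hypothetical column `AGE` migrating from TEXT to NUMBER, with an invented UUID suffix:

```kotlin
val typeChangeSteps = listOf(
    // 1. Add a nullable scratch column of the new type.
    """ALTER TABLE "AIRBYTE_DB"."PUBLIC"."USERS" ADD COLUMN "AGE_a1b2c3d4" NUMBER;""",
    // 2. Backfill it by casting the existing values.
    """UPDATE "AIRBYTE_DB"."PUBLIC"."USERS" SET "AGE_a1b2c3d4" = CAST("AGE" AS NUMBER);""",
    // 3. Move the original column out of the way.
    """
    ALTER TABLE "AIRBYTE_DB"."PUBLIC"."USERS"
    RENAME COLUMN "AGE" TO "AGE_a1b2c3d4_backup";
    """.trimIndent(),
    // 4. Promote the scratch column to the original name.
    """
    ALTER TABLE "AIRBYTE_DB"."PUBLIC"."USERS"
    RENAME COLUMN "AGE_a1b2c3d4" TO "AGE";
    """.trimIndent(),
    // 5. Drop the backup.
    """ALTER TABLE "AIRBYTE_DB"."PUBLIC"."USERS" DROP COLUMN "AGE_a1b2c3d4_backup";""",
)
```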
+
+    @VisibleForTesting
+    fun fullyQualifiedName(tableName: TableName): String =
+        combineParts(listOf(getDatabaseName(), tableName.namespace, tableName.name))
+
+    @VisibleForTesting
+    fun fullyQualifiedNamespace(namespace: String) =
+        combineParts(listOf(getDatabaseName(), namespace))
+
+    @VisibleForTesting
+    fun fullyQualifiedStageName(tableName: TableName, escape: Boolean = false): String {
+        val currentTableName =
+            if (escape) {
+                tableName.name
+            } else {
+                tableName.name
+            }
+        return combineParts(
+            parts =
+                listOf(
+                    getDatabaseName(),
+                    tableName.namespace,
+                    "$STAGE_NAME_PREFIX$currentTableName",
+                ),
+            escape = escape,
+        )
+    }
+
+    @VisibleForTesting
+    internal fun combineParts(parts: List<String>, escape: Boolean = false): String =
+        parts
+            .map { if (escape) sqlEscape(it) else it }
+            .joinToString(separator = ".") {
+                if (!it.startsWith(QUOTE)) {
+                    "$QUOTE$it$QUOTE"
+                } else {
+                    it
+                }
+            }
+
+    private fun getDatabaseName() = config.database.toSnowflakeCompatibleName()
 }
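The quoting rule in `combineParts` is simple: quote each part unless it already begins with a quote, then join with dots; with `escape = true` every part is also run through `sqlEscape` first, which matters for stage names embedded in string literals. A self-contained sketch of the joining behavior (constant and function names are local stand-ins, not the real ones):

```kotlin
const val Q = "\"" // local stand-in for QUOTE, to keep the sketch runnable

fun combinePartsSketch(parts: List<String>): String =
    parts.joinToString(".") { if (!it.startsWith(Q)) "$Q$it$Q" else it }

fun main() {
    // -> "AIRBYTE_DB"."PUBLIC"."USERS"
    println(combinePartsSketch(listOf("AIRBYTE_DB", "PUBLIC", "USERS")))
    // An already-quoted part is passed through untouched:
    println(combinePartsSketch(listOf("\"AIRBYTE_DB\"", "PUBLIC", "USERS")))
}
```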
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.integrations.destination.snowflake.sql
+
+const val STAGE_NAME_PREFIX = "airbyte_stage_"
+internal const val QUOTE: String = "\""
+
+fun sqlEscape(part: String) = part.replace("\\", "\\\\").replace("'", "\\'").replace("\"", "\\\"")
+
+/**
+ * Surrounds the string instance with double quotation marks (e.g. "some string" -> "\"some
+ * string\"").
+ */
+fun String.quote() = "$QUOTE$this$QUOTE"
+
+/**
+ * Escapes double-quotes in a JSON identifier by doubling them. This is legacy -- I don't know why
+ * this would be necessary but no harm in keeping it, so I am keeping it.
+ *
+ * @return The escaped identifier.
+ */
+fun escapeJsonIdentifier(identifier: String): String {
+    // Note that we don't need to escape backslashes here!
+    // The only special character in an identifier is the double-quote, which needs to be
+    // doubled.
+    return identifier.replace(QUOTE, "$QUOTE$QUOTE")
+}
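Quick reference for the three helpers in this new file, with the expected results worked out from the replace chains above (inputs invented; the helpers are redeclared locally so the sketch runs standalone):

```kotlin
fun main() {
    val quote = "\""
    fun String.quote() = "$quote$this$quote"
    fun sqlEscape(part: String) =
        part.replace("\\", "\\\\").replace("'", "\\'").replace("\"", "\\\"")
    fun escapeJsonIdentifier(identifier: String) = identifier.replace(quote, "$quote$quote")

    println("USERS".quote())                  // "USERS" (wrapped in literal double quotes)
    println(sqlEscape("""a"b'c\d"""))         // a\"b\'c\\d (backslash first, then ' and ")
    println(escapeJsonIdentifier("""a"b"""))  // a""b (double-quote doubled)
}
```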
@@ -1,57 +0,0 @@
-/*
- * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.snowflake.sql
-
-import io.airbyte.cdk.load.table.TableName
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
-import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import jakarta.inject.Singleton
-
-const val STAGE_NAME_PREFIX = "airbyte_stage_"
-internal const val QUOTE: String = "\""
-
-fun sqlEscape(part: String) = part.replace("\\", "\\\\").replace("'", "\\'").replace("\"", "\\\"")
-
-@Singleton
-class SnowflakeSqlNameUtils(
-    private val snowflakeConfiguration: SnowflakeConfiguration,
-) {
-    fun fullyQualifiedName(tableName: TableName): String =
-        combineParts(listOf(getDatabaseName(), tableName.namespace, tableName.name))
-
-    fun fullyQualifiedNamespace(namespace: String) =
-        combineParts(listOf(getDatabaseName(), namespace))
-
-    fun fullyQualifiedStageName(tableName: TableName, escape: Boolean = false): String {
-        val currentTableName =
-            if (escape) {
-                tableName.name
-            } else {
-                tableName.name
-            }
-        return combineParts(
-            parts =
-                listOf(
-                    getDatabaseName(),
-                    tableName.namespace,
-                    "$STAGE_NAME_PREFIX$currentTableName"
-                ),
-            escape = escape,
-        )
-    }
-
-    fun combineParts(parts: List<String>, escape: Boolean = false): String =
-        parts
-            .map { if (escape) sqlEscape(it) else it }
-            .joinToString(separator = ".") {
-                if (!it.startsWith(QUOTE)) {
-                    "$QUOTE$it$QUOTE"
-                } else {
-                    it
-                }
-            }
-
-    private fun getDatabaseName() = snowflakeConfiguration.database.toSnowflakeCompatibleName()
-}
@@ -6,38 +6,39 @@ package io.airbyte.integrations.destination.snowflake.write
 
 import io.airbyte.cdk.SystemErrorException
 import io.airbyte.cdk.load.command.Dedupe
+import io.airbyte.cdk.load.command.DestinationCatalog
 import io.airbyte.cdk.load.command.DestinationStream
-import io.airbyte.cdk.load.orchestration.db.DatabaseInitialStatusGatherer
-import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadInitialStatus
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendStreamLoader
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendTruncateStreamLoader
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableDedupStreamLoader
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableDedupTruncateStreamLoader
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig
-import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
+import io.airbyte.cdk.load.table.ColumnNameMapping
+import io.airbyte.cdk.load.table.DatabaseInitialStatusGatherer
+import io.airbyte.cdk.load.table.TempTableNameGenerator
+import io.airbyte.cdk.load.table.directload.DirectLoadInitialStatus
+import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendStreamLoader
+import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendTruncateStreamLoader
+import io.airbyte.cdk.load.table.directload.DirectLoadTableDedupStreamLoader
+import io.airbyte.cdk.load.table.directload.DirectLoadTableDedupTruncateStreamLoader
+import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
 import io.airbyte.cdk.load.write.DestinationWriter
 import io.airbyte.cdk.load.write.StreamLoader
 import io.airbyte.cdk.load.write.StreamStateStore
 import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
-import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
+import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
 import jakarta.inject.Singleton
 
 @Singleton
 class SnowflakeWriter(
-    private val names: TableCatalog,
+    private val catalog: DestinationCatalog,
     private val stateGatherer: DatabaseInitialStatusGatherer<DirectLoadInitialStatus>,
     private val streamStateStore: StreamStateStore<DirectLoadTableExecutionConfig>,
     private val snowflakeClient: SnowflakeAirbyteClient,
-    private val tempTableNameGenerator: TempTableNameGenerator,
     private val snowflakeConfiguration: SnowflakeConfiguration,
+    private val tempTableNameGenerator: TempTableNameGenerator,
 ) : DestinationWriter {
     private lateinit var initialStatuses: Map<DestinationStream, DirectLoadInitialStatus>
 
     override suspend fun setup() {
-        names.values
-            .map { (tableNames, _) -> tableNames.finalTableName!!.namespace }
+        catalog.streams
+            .map { it.tableSchema.tableNames.finalTableName!!.namespace }
             .toSet()
             .forEach { snowflakeClient.createNamespace(it) }
 
@@ -45,15 +46,15 @@ class SnowflakeWriter(
             escapeJsonIdentifier(snowflakeConfiguration.internalTableSchema)
         )
 
-        initialStatuses = stateGatherer.gatherInitialStatus(names)
+        initialStatuses = stateGatherer.gatherInitialStatus()
     }
 
     override fun createStreamLoader(stream: DestinationStream): StreamLoader {
         val initialStatus = initialStatuses[stream]!!
-        val tableNameInfo = names[stream]!!
-        val realTableName = tableNameInfo.tableNames.finalTableName!!
-        val tempTableName = tempTableNameGenerator.generate(realTableName)
-        val columnNameMapping = tableNameInfo.columnNameMapping
+        val realTableName = stream.tableSchema.tableNames.finalTableName!!
+        val tempTableName = stream.tableSchema.tableNames.tempTableName!!
+        val columnNameMapping =
+            ColumnNameMapping(stream.tableSchema.columnSchema.inputToFinalColumnNames)
        return when (stream.minimumGenerationId) {
             0L ->
                 when (stream.importType) {
@@ -9,11 +9,12 @@ import de.siegmar.fastcsv.writer.CsvWriter
 import de.siegmar.fastcsv.writer.LineDelimiter
 import de.siegmar.fastcsv.writer.QuoteStrategies
 import io.airbyte.cdk.load.data.AirbyteValue
-import io.airbyte.cdk.load.table.TableName
+import io.airbyte.cdk.load.schema.model.ColumnSchema
+import io.airbyte.cdk.load.schema.model.TableName
 import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
 import io.airbyte.integrations.destination.snowflake.sql.QUOTE
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
 import io.github.oshai.kotlinlogging.KotlinLogging
 import java.io.File
 import java.io.OutputStream
@@ -36,10 +37,11 @@ private const val CSV_WRITER_BUFFER_SIZE = 1024 * 1024 // 1 MB
 
 class SnowflakeInsertBuffer(
     private val tableName: TableName,
-    val columns: LinkedHashMap<String, String>,
     private val snowflakeClient: SnowflakeAirbyteClient,
     val snowflakeConfiguration: SnowflakeConfiguration,
-    val snowflakeColumnUtils: SnowflakeColumnUtils,
+    val columnSchema: ColumnSchema,
+    private val columnManager: SnowflakeColumnManager,
+    private val snowflakeRecordFormatter: SnowflakeRecordFormatter,
     private val flushLimit: Int = DEFAULT_FLUSH_LIMIT,
 ) {
 
@@ -57,12 +59,6 @@ class SnowflakeInsertBuffer(
             .lineDelimiter(CSV_LINE_DELIMITER)
             .quoteStrategy(QuoteStrategies.REQUIRED)
 
-    private val snowflakeRecordFormatter: SnowflakeRecordFormatter =
-        when (snowflakeConfiguration.legacyRawTablesOnly) {
-            true -> SnowflakeRawRecordFormatter(columns, snowflakeColumnUtils)
-            else -> SnowflakeSchemaRecordFormatter(columns, snowflakeColumnUtils)
-        }
-
     fun accumulate(recordFields: Map<String, AirbyteValue>) {
         if (csvFilePath == null) {
             val csvFile = createCsvFile()
@@ -92,7 +88,9 @@ class SnowflakeInsertBuffer(
             "Copying staging data into ${tableName.toPrettyString(quote = QUOTE)}..."
         }
         // Finally, copy the data from the staging table to the final table
-        snowflakeClient.copyFromStage(tableName, filePath.fileName.toString())
+        // Pass column names to ensure correct mapping even after ALTER TABLE operations
+        val columnNames = columnManager.getTableColumnNames(columnSchema)
+        snowflakeClient.copyFromStage(tableName, filePath.fileName.toString(), columnNames)
         logger.info {
             "Finished insert of $recordCount row(s) into ${tableName.toPrettyString(quote = QUOTE)}."
         }
@@ -117,7 +115,9 @@ class SnowflakeInsertBuffer(
 
     private fun writeToCsvFile(record: Map<String, AirbyteValue>) {
         csvWriter?.let {
-            it.writeRecord(snowflakeRecordFormatter.format(record).map { col -> col.toString() })
+            it.writeRecord(
+                snowflakeRecordFormatter.format(record, columnSchema).map { col -> col.toString() }
+            )
             recordCount++
             if ((recordCount % flushLimit) == 0) {
                 it.flush()
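With the formatter now a constructor dependency, the raw-vs-schema choice that used to live inside the buffer (`when (snowflakeConfiguration.legacyRawTablesOnly)`) presumably moves to the DI layer. A hypothetical Micronaut factory of the shape this diff implies; the factory itself is invented for illustration and is not part of this commit:

```kotlin
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRawRecordFormatter
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter
import io.micronaut.context.annotation.Factory
import jakarta.inject.Singleton

@Factory
class SnowflakeRecordFormatterFactory { // hypothetical name
    // Pick the formatter once, based on the legacy flag, and let the container
    // inject it into SnowflakeInsertBuffer's constructor.
    @Singleton
    fun recordFormatter(config: SnowflakeConfiguration): SnowflakeRecordFormatter =
        if (config.legacyRawTablesOnly == true) SnowflakeRawRecordFormatter()
        else SnowflakeSchemaRecordFormatter()
}
```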
@@ -8,99 +8,62 @@ import io.airbyte.cdk.load.data.AirbyteValue
 import io.airbyte.cdk.load.data.NullValue
 import io.airbyte.cdk.load.data.StringValue
 import io.airbyte.cdk.load.data.csv.toCsvValue
-import io.airbyte.cdk.load.message.Meta
+import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
+import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
+import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT
+import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
+import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
+import io.airbyte.cdk.load.schema.model.ColumnSchema
 import io.airbyte.cdk.load.util.Jsons
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
 
 interface SnowflakeRecordFormatter {
-    fun format(record: Map<String, AirbyteValue>): List<Any>
+    fun format(record: Map<String, AirbyteValue>, columnSchema: ColumnSchema): List<Any>
 }
 
-class SnowflakeSchemaRecordFormatter(
-    private val columns: LinkedHashMap<String, String>,
-    val snowflakeColumnUtils: SnowflakeColumnUtils,
-) : SnowflakeRecordFormatter {
-
-    private val airbyteColumnNames =
-        snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet()
-
-    override fun format(record: Map<String, AirbyteValue>): List<Any> =
-        columns.map { (columnName, _) ->
-            /*
-             * Meta columns are forced to uppercase for backwards compatibility with previous
-             * versions of the destination. Therefore, convert the column to lowercase so
-             * that it can match the constants, which use the lowercase version of the meta
-             * column names.
-             */
-            if (airbyteColumnNames.contains(columnName)) {
-                record[columnName.lowercase()].toCsvValue()
-            } else {
-                record.keys
-                    // The columns retrieved from Snowflake do not have any escaping applied.
-                    // Therefore, re-apply the compatible name escaping to the name of the
-                    // columns retrieved from Snowflake. The record keys should already have
-                    // been escaped by the CDK before arriving at the aggregate, so no need
-                    // to escape again here.
-                    .find { it == columnName.toSnowflakeCompatibleName() }
-                    ?.let { record[it].toCsvValue() }
-                    ?: ""
-            }
-        }
+class SnowflakeSchemaRecordFormatter : SnowflakeRecordFormatter {
+    override fun format(record: Map<String, AirbyteValue>, columnSchema: ColumnSchema): List<Any> {
+        val result = mutableListOf<Any>()
+        val userColumns = columnSchema.finalSchema.keys
+
+        // WARNING: MUST match the order defined in SnowflakeColumnManager#getTableColumnNames
+        //
+        // Why don't we just use that here? Well, unlike the user fields, the meta fields on the
+        // record are not munged for the destination. So we must access the values for those columns
+        // using the original lowercase meta key.
+        result.add(record[COLUMN_NAME_AB_RAW_ID].toCsvValue())
+        result.add(record[COLUMN_NAME_AB_EXTRACTED_AT].toCsvValue())
+        result.add(record[COLUMN_NAME_AB_META].toCsvValue())
+        result.add(record[COLUMN_NAME_AB_GENERATION_ID].toCsvValue())
+
+        // Add user columns from the final schema
+        userColumns.forEach { columnName -> result.add(record[columnName].toCsvValue()) }
+
+        return result
+    }
 }
 
-class SnowflakeRawRecordFormatter(
-    columns: LinkedHashMap<String, String>,
-    val snowflakeColumnUtils: SnowflakeColumnUtils,
-) : SnowflakeRecordFormatter {
-    private val columns = columns.keys
-
-    private val airbyteColumnNames =
-        snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet()
-
-    override fun format(record: Map<String, AirbyteValue>): List<Any> =
+class SnowflakeRawRecordFormatter : SnowflakeRecordFormatter {
+    override fun format(record: Map<String, AirbyteValue>, columnSchema: ColumnSchema): List<Any> =
         toOutputRecord(record.toMutableMap())
 
     private fun toOutputRecord(record: MutableMap<String, AirbyteValue>): List<Any> {
         val outputRecord = mutableListOf<Any>()
-        // Copy the Airbyte metadata columns to the raw output, removing each
-        // one from the record to avoid duplicates in the "data" field
-        columns
-            .filter { airbyteColumnNames.contains(it) && it != Meta.COLUMN_NAME_DATA }
-            .forEach { column -> safeAddToOutput(column, record, outputRecord) }
+        val mutableRecord = record.toMutableMap()
+
+        // Add meta columns in order (except _airbyte_data which we handle specially)
+        outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_RAW_ID)?.toCsvValue() ?: "")
+        outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_EXTRACTED_AT)?.toCsvValue() ?: "")
+        outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_META)?.toCsvValue() ?: "")
+        outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_GENERATION_ID)?.toCsvValue() ?: "")
+        outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_LOADED_AT)?.toCsvValue() ?: "")
 
         // Do not output null values in the JSON raw output
-        val filteredRecord = record.filter { (_, v) -> v !is NullValue }
-        // Convert all the remaining columns in the record to a JSON document stored in the "data"
-        // column. Add it in the same position as the _airbyte_data column in the column list to
-        // ensure it is inserted into the proper column in the table.
-        insert(
-            columns.indexOf(Meta.COLUMN_NAME_DATA),
-            StringValue(Jsons.writeValueAsString(filteredRecord)).toCsvValue(),
-            outputRecord
-        )
+        val filteredRecord = mutableRecord.filter { (_, v) -> v !is NullValue }
+
+        // Convert all the remaining columns to a JSON document stored in the "data" column
+        outputRecord.add(StringValue(Jsons.writeValueAsString(filteredRecord)).toCsvValue())
         return outputRecord
     }
-
-    private fun safeAddToOutput(
-        key: String,
-        record: MutableMap<String, AirbyteValue>,
-        output: MutableList<Any>
-    ) {
-        val extractedValue = record.remove(key)
-        // Ensure that the data is inserted into the list at the same position as the column
-        insert(columns.indexOf(key), extractedValue?.toCsvValue() ?: "", output)
-    }
-
-    private fun insert(index: Int, value: Any, list: MutableList<Any>) {
-        /*
-         * Attempt to insert the value into the proper order in the list. If the index
-         * is already present in the list, use the add(index, element) method to insert it
-         * into the proper order and push everything to the right. If the index is at the
-         * end of the list, just use add(element) to insert it at the end. If the index
-         * is further beyond the end of the list, throw an exception as that should not occur.
-         */
-        if (index < list.size) list.add(index, value)
-        else if (index == list.size || index == list.size + 1) list.add(value)
-        else throw IndexOutOfBoundsException()
-    }
 }
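Concretely, the schema formatter's output order is the fixed meta quartet followed by `finalSchema` keys; the raw formatter instead emits the five meta columns and a single JSON "data" cell. An illustrative expectation for a record with one user column (all values invented; meta keys use the original lowercase Airbyte names, as the comment in `format()` explains):

```kotlin
import io.airbyte.cdk.load.data.StringValue

fun main() {
    val record = mapOf(
        "_airbyte_raw_id" to StringValue("b0c1d2e3"),
        "_airbyte_extracted_at" to StringValue("2025-01-01T00:00:00Z"),
        "_airbyte_meta" to StringValue("{}"),
        "_airbyte_generation_id" to StringValue("7"),
        "NAME" to StringValue("ada"),
    )
    // With finalSchema = { NAME }, SnowflakeSchemaRecordFormatter would yield the
    // CSV cells in table order: [b0c1d2e3, 2025-01-01T00:00:00Z, {}, 7, ada].
    println(record.keys)
}
```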
@@ -1,25 +0,0 @@
-/*
- * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.snowflake.write.transform
-
-import io.airbyte.cdk.load.command.DestinationStream
-import io.airbyte.cdk.load.dataflow.transform.ColumnNameMapper
-import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
-import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import jakarta.inject.Singleton
-
-@Singleton
-class SnowflakeColumnNameMapper(
-    private val catalogInfo: TableCatalog,
-    private val snowflakeConfiguration: SnowflakeConfiguration,
-) : ColumnNameMapper {
-    override fun getMappedColumnName(stream: DestinationStream, columnName: String): String {
-        if (snowflakeConfiguration.legacyRawTablesOnly == true) {
-            return columnName
-        } else {
-            return catalogInfo.getMappedColumnName(stream, columnName)!!
-        }
-    }
-}
@@ -7,9 +7,11 @@ package io.airbyte.integrations.destination.snowflake.component
|
|||||||
import io.airbyte.cdk.load.component.TableOperationsFixtures
|
import io.airbyte.cdk.load.component.TableOperationsFixtures
|
||||||
import io.airbyte.cdk.load.component.TableOperationsSuite
|
import io.airbyte.cdk.load.component.TableOperationsSuite
|
||||||
import io.airbyte.cdk.load.message.Meta
|
import io.airbyte.cdk.load.message.Meta
|
||||||
|
import io.airbyte.cdk.load.schema.TableSchemaFactory
|
||||||
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
|
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
|
||||||
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.idTestWithCdcMapping
|
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.idTestWithCdcMapping
|
||||||
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.testMapping
|
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.testMapping
|
||||||
|
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeTestTableOperationsClient
|
||||||
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
|
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
|
||||||
import org.junit.jupiter.api.Test
|
import org.junit.jupiter.api.Test
|
||||||
import org.junit.jupiter.api.parallel.Execution
|
import org.junit.jupiter.api.parallel.Execution
|
||||||
@@ -20,6 +22,7 @@ import org.junit.jupiter.api.parallel.ExecutionMode
|
|||||||
class SnowflakeTableOperationsTest(
|
class SnowflakeTableOperationsTest(
|
||||||
override val client: SnowflakeAirbyteClient,
|
override val client: SnowflakeAirbyteClient,
|
||||||
override val testClient: SnowflakeTestTableOperationsClient,
|
override val testClient: SnowflakeTestTableOperationsClient,
|
||||||
|
override val schemaFactory: TableSchemaFactory,
|
||||||
) : TableOperationsSuite {
|
) : TableOperationsSuite {
|
||||||
override val airbyteMetaColumnMapping = Meta.COLUMN_NAMES.associateWith { it.uppercase() }
|
override val airbyteMetaColumnMapping = Meta.COLUMN_NAMES.associateWith { it.uppercase() }
|
||||||
|
|
||||||
|
|||||||
@@ -9,13 +9,15 @@ import io.airbyte.cdk.load.component.TableSchemaEvolutionFixtures
|
|||||||
import io.airbyte.cdk.load.component.TableSchemaEvolutionSuite
|
import io.airbyte.cdk.load.component.TableSchemaEvolutionSuite
|
||||||
import io.airbyte.cdk.load.data.StringValue
|
import io.airbyte.cdk.load.data.StringValue
|
||||||
import io.airbyte.cdk.load.message.Meta
|
import io.airbyte.cdk.load.message.Meta
|
||||||
|
import io.airbyte.cdk.load.schema.TableSchemaFactory
|
||||||
import io.airbyte.cdk.load.table.ColumnNameMapping
|
import io.airbyte.cdk.load.table.ColumnNameMapping
|
||||||
import io.airbyte.cdk.load.util.serializeToString
|
import io.airbyte.cdk.load.util.serializeToString
|
||||||
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
|
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
|
||||||
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.allTypesColumnNameMapping
|
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.allTypesColumnNameMapping
|
||||||
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.allTypesTableSchema
|
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.allTypesTableSchema
|
||||||
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.idAndTestMapping
|
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.idAndTestMapping
|
||||||
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.testMapping
|
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.testMapping
|
||||||
|
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeTestTableOperationsClient
|
||||||
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
|
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
|
||||||
import org.junit.jupiter.api.Test
|
import org.junit.jupiter.api.Test
|
||||||
import org.junit.jupiter.api.parallel.Execution
|
import org.junit.jupiter.api.parallel.Execution
|
||||||
@@ -27,6 +29,7 @@ class SnowflakeTableSchemaEvolutionTest(
|
|||||||
override val client: SnowflakeAirbyteClient,
|
override val client: SnowflakeAirbyteClient,
|
||||||
override val opsClient: SnowflakeAirbyteClient,
|
override val opsClient: SnowflakeAirbyteClient,
|
||||||
override val testClient: SnowflakeTestTableOperationsClient,
|
override val testClient: SnowflakeTestTableOperationsClient,
|
||||||
|
override val schemaFactory: TableSchemaFactory,
|
||||||
) : TableSchemaEvolutionSuite {
|
) : TableSchemaEvolutionSuite {
|
||||||
override val airbyteMetaColumnMapping = Meta.COLUMN_NAMES.associateWith { it.uppercase() }
|
override val airbyteMetaColumnMapping = Meta.COLUMN_NAMES.associateWith { it.uppercase() }
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
  * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
  */
 
-package io.airbyte.integrations.destination.snowflake.component
+package io.airbyte.integrations.destination.snowflake.component.config
 
 import io.airbyte.cdk.load.component.ColumnType
 import io.airbyte.cdk.load.component.TableOperationsFixtures
@@ -33,6 +33,8 @@ object SnowflakeComponentTestFixtures {
            "TIME_NTZ" to ColumnType("TIME", true),
            "ARRAY" to ColumnType("ARRAY", true),
            "OBJECT" to ColumnType("OBJECT", true),
+           "UNION" to ColumnType("VARIANT", true),
+           "LEGACY_UNION" to ColumnType("VARIANT", true),
            "UNKNOWN" to ColumnType("VARIANT", true),
        )
    )
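
A note on the two new fixture rows: they pin down how both the current and the legacy union representations of an Airbyte type are expected to surface in Snowflake, namely as nullable VARIANT columns, the same physical type already used for UNKNOWN. A minimal self-contained sketch of that lookup (this `ColumnType` is a stand-in for the CDK class of the same name, not its real definition):

```kotlin
// Stand-in for io.airbyte.cdk.load.component.ColumnType: a type name plus nullability.
data class ColumnType(val typeName: String, val nullable: Boolean)

// Mirrors the fixture table above: semi-structured and union-like Airbyte types
// all land on Snowflake's VARIANT column type.
val semiStructuredTypes = mapOf(
    "ARRAY" to ColumnType("ARRAY", true),
    "OBJECT" to ColumnType("OBJECT", true),
    "UNION" to ColumnType("VARIANT", true),
    "LEGACY_UNION" to ColumnType("VARIANT", true),
    "UNKNOWN" to ColumnType("VARIANT", true),
)

fun main() {
    // Both union flavors resolve to the same physical column type.
    check(semiStructuredTypes["UNION"] == semiStructuredTypes["LEGACY_UNION"])
    println(semiStructuredTypes["UNION"]) // ColumnType(typeName=VARIANT, nullable=true)
}
```
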
@@ -2,7 +2,7 @@
  * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
  */
 
-package io.airbyte.integrations.destination.snowflake.component
+package io.airbyte.integrations.destination.snowflake.component.config
 
 import io.airbyte.cdk.load.component.config.TestConfigLoader.loadTestConfig
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
@@ -2,22 +2,24 @@
  * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
  */
 
-package io.airbyte.integrations.destination.snowflake.component
+package io.airbyte.integrations.destination.snowflake.component.config
 
+import io.airbyte.cdk.load.component.ColumnType
 import io.airbyte.cdk.load.component.TestTableOperationsClient
 import io.airbyte.cdk.load.data.AirbyteValue
 import io.airbyte.cdk.load.dataflow.state.PartitionKey
 import io.airbyte.cdk.load.dataflow.transform.RecordDTO
-import io.airbyte.cdk.load.table.TableName
+import io.airbyte.cdk.load.schema.model.TableName
 import io.airbyte.cdk.load.util.Jsons
 import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
 import io.airbyte.integrations.destination.snowflake.client.execute
 import io.airbyte.integrations.destination.snowflake.dataflow.SnowflakeAggregate
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils
+import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
 import io.airbyte.integrations.destination.snowflake.sql.andLog
 import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
+import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
 import io.micronaut.context.annotation.Requires
 import jakarta.inject.Singleton
 import java.time.format.DateTimeFormatter
@@ -29,25 +31,40 @@ import net.snowflake.client.jdbc.SnowflakeTimestampWithTimezone
 class SnowflakeTestTableOperationsClient(
     private val client: SnowflakeAirbyteClient,
     private val dataSource: DataSource,
-    private val snowflakeSqlNameUtils: SnowflakeSqlNameUtils,
-    private val snowflakeColumnUtils: SnowflakeColumnUtils,
+    private val sqlGenerator: SnowflakeDirectLoadSqlGenerator,
     private val snowflakeConfiguration: SnowflakeConfiguration,
+    private val columnManager: SnowflakeColumnManager,
+    private val snowflakeRecordFormatter: SnowflakeRecordFormatter,
 ) : TestTableOperationsClient {
     override suspend fun dropNamespace(namespace: String) {
         dataSource.execute(
-            "DROP SCHEMA IF EXISTS ${snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)}".andLog()
+            "DROP SCHEMA IF EXISTS ${sqlGenerator.fullyQualifiedNamespace(namespace)}".andLog()
        )
    }
 
     override suspend fun insertRecords(table: TableName, records: List<Map<String, AirbyteValue>>) {
+        // TODO: we should just pass a proper column schema
+        // Since we don't pass in a proper column schema, we have to recreate one here
+        // Fetch the columns and filter out the meta columns so we're just looking at user columns
+        val columnTypes =
+            client.describeTable(table).filterNot {
+                columnManager.getMetaColumnNames().contains(it.key)
+            }
+        val columnSchema =
+            io.airbyte.cdk.load.schema.model.ColumnSchema(
+                inputToFinalColumnNames = columnTypes.keys.associateWith { it },
+                finalSchema = columnTypes.mapValues { (_, _) -> ColumnType("", true) },
+                inputSchema = emptyMap() // Not needed for insert buffer
+            )
        val a =
            SnowflakeAggregate(
                SnowflakeInsertBuffer(
-                    table,
-                    client.describeTable(table),
-                    client,
-                    snowflakeConfiguration,
-                    snowflakeColumnUtils,
+                    tableName = table,
+                    snowflakeClient = client,
+                    snowflakeConfiguration = snowflakeConfiguration,
+                    columnSchema = columnSchema,
+                    columnManager = columnManager,
+                    snowflakeRecordFormatter = snowflakeRecordFormatter,
                )
            )
        records.forEach { a.accept(RecordDTO(it, PartitionKey(""), 0, 0)) }
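
The rewritten `insertRecords` has no `ColumnSchema` handed to it, so it rebuilds one from a live `DESCRIBE TABLE` result, dropping the connector-managed meta columns first. That filter-then-rebuild step is easy to miss in the diff; here is a self-contained sketch of the same shape, where `ColumnType`, `ColumnSchema`, and the column names are stand-ins rather than the CDK's actual definitions:

```kotlin
data class ColumnType(val typeName: String, val nullable: Boolean)

// Stand-in for io.airbyte.cdk.load.schema.model.ColumnSchema.
data class ColumnSchema(
    val inputToFinalColumnNames: Map<String, String>,
    val finalSchema: Map<String, ColumnType>,
    val inputSchema: Map<String, ColumnType>,
)

fun rebuildColumnSchema(
    describedColumns: Map<String, ColumnType>, // what DESCRIBE TABLE returned
    metaColumnNames: Set<String>,              // e.g. _AIRBYTE_RAW_ID and friends
): ColumnSchema {
    // Keep only user columns; the meta columns are managed by the connector itself.
    val userColumns = describedColumns.filterNot { it.key in metaColumnNames }
    return ColumnSchema(
        // Identity mapping works here: names read back from the database
        // already are the final names.
        inputToFinalColumnNames = userColumns.keys.associateWith { it },
        finalSchema = userColumns,
        inputSchema = emptyMap(), // not needed by the insert buffer
    )
}

fun main() {
    val described = mapOf(
        "_AIRBYTE_RAW_ID" to ColumnType("VARCHAR", false),
        "ID" to ColumnType("NUMBER", true),
    )
    val schema = rebuildColumnSchema(described, setOf("_AIRBYTE_RAW_ID"))
    println(schema.finalSchema.keys) // [ID]
}
```
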
@@ -7,11 +7,11 @@ package io.airbyte.integrations.destination.snowflake.write
 import io.airbyte.cdk.load.test.util.DestinationCleaner
 import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory
 import io.airbyte.integrations.destination.snowflake.cdk.SnowflakeMigratingConfigurationSpecificationSupplier
-import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfigurationFactory
 import io.airbyte.integrations.destination.snowflake.sql.STAGE_NAME_PREFIX
+import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
 import io.airbyte.integrations.destination.snowflake.sql.quote
 import java.nio.file.Files
 import java.sql.Connection
@@ -12,16 +12,22 @@ import io.airbyte.cdk.load.data.ObjectValue
 import io.airbyte.cdk.load.data.json.toAirbyteValue
 import io.airbyte.cdk.load.message.Meta
 import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN
+import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator
 import io.airbyte.cdk.load.test.util.DestinationDataDumper
 import io.airbyte.cdk.load.test.util.OutputRecord
+import io.airbyte.cdk.load.util.UUIDGenerator
 import io.airbyte.cdk.load.util.deserializeToNode
 import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory
-import io.airbyte.integrations.destination.snowflake.db.SnowflakeFinalTableNameGenerator
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeTableSchemaMapper
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils
+import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
 import io.airbyte.integrations.destination.snowflake.sql.sqlEscape
 import java.math.BigDecimal
+import java.sql.Date
+import java.sql.Time
+import java.sql.Timestamp
 import net.snowflake.client.jdbc.SnowflakeTimestampWithTimezone
 
 private val AIRBYTE_META_COLUMNS = Meta.COLUMN_NAMES + setOf(CDC_DELETED_AT_COLUMN)
@@ -34,8 +40,14 @@ class SnowflakeDataDumper(
         stream: DestinationStream
     ): List<OutputRecord> {
         val config = configProvider(spec)
-        val sqlUtils = SnowflakeSqlNameUtils(config)
-        val snowflakeFinalTableNameGenerator = SnowflakeFinalTableNameGenerator(config)
+        val snowflakeFinalTableNameGenerator =
+            SnowflakeTableSchemaMapper(
+                config = config,
+                tempTableNameGenerator = DefaultTempTableNameGenerator(),
+            )
+        val snowflakeColumnManager = SnowflakeColumnManager(config)
+        val sqlGenerator =
+            SnowflakeDirectLoadSqlGenerator(UUIDGenerator(), config, snowflakeColumnManager)
         val dataSource =
             SnowflakeBeanFactory()
                 .snowflakeDataSource(snowflakeConfiguration = config, airbyteEdition = "COMMUNITY")
@@ -46,7 +58,7 @@ class SnowflakeDataDumper(
            ds.connection.use { connection ->
                val statement = connection.createStatement()
                val tableName =
-                    snowflakeFinalTableNameGenerator.getTableName(stream.mappedDescriptor)
+                    snowflakeFinalTableNameGenerator.toFinalTableName(stream.mappedDescriptor)
 
                // First check if the table exists
                val tableExistsQuery =
@@ -69,7 +81,7 @@ class SnowflakeDataDumper(
 
                val resultSet =
                    statement.executeQuery(
-                        "SELECT * FROM ${sqlUtils.fullyQualifiedName(tableName)}"
+                        "SELECT * FROM ${sqlGenerator.fullyQualifiedName(tableName)}"
                    )
 
                while (resultSet.next()) {
@@ -143,10 +155,10 @@ class SnowflakeDataDumper(
     private fun convertValue(value: Any?): Any? =
         when (value) {
            is BigDecimal -> value.toBigInteger()
-           is java.sql.Date -> value.toLocalDate()
+           is Date -> value.toLocalDate()
            is SnowflakeTimestampWithTimezone -> value.toZonedDateTime()
-           is java.sql.Time -> value.toLocalTime()
-           is java.sql.Timestamp -> value.toLocalDateTime()
+           is Time -> value.toLocalTime()
+           is Timestamp -> value.toLocalDateTime()
            else -> value
        }
 }
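
The `convertValue` hunk is only an import cleanup (the `java.sql` prefixes move up into the import list), but the conversion itself rests on standard JDBC accessors, so it can be checked in isolation. A runnable JDK-only sketch of the same mapping; the `SnowflakeTimestampWithTimezone` branch is omitted because it needs the Snowflake JDBC driver on the classpath:

```kotlin
import java.math.BigDecimal
import java.sql.Date
import java.sql.Time
import java.sql.Timestamp

// Same shape as the dumper's convertValue: normalize JDBC temporal and numeric
// wrappers into java.time / BigInteger values before comparison.
fun convertValue(value: Any?): Any? =
    when (value) {
        is BigDecimal -> value.toBigInteger()
        is Date -> value.toLocalDate()          // java.sql.Date -> LocalDate
        is Time -> value.toLocalTime()          // java.sql.Time -> LocalTime
        is Timestamp -> value.toLocalDateTime() // java.sql.Timestamp -> LocalDateTime
        else -> value
    }

fun main() {
    println(convertValue(Date.valueOf("2025-01-31")))               // 2025-01-31
    println(convertValue(Time.valueOf("12:34:56")))                 // 12:34:56
    println(convertValue(Timestamp.valueOf("2025-01-31 12:34:56"))) // 2025-01-31T12:34:56
    println(convertValue(BigDecimal("42.0")))                       // 42
}
```
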
@@ -15,7 +15,7 @@ import io.airbyte.cdk.load.dataflow.transform.ValidationResult
 import io.airbyte.cdk.load.message.Meta
 import io.airbyte.cdk.load.test.util.ExpectedRecordMapper
 import io.airbyte.cdk.load.test.util.OutputRecord
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
 import io.airbyte.integrations.destination.snowflake.write.transform.SnowflakeValueCoercer
 import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange
 import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Change
@@ -5,7 +5,7 @@
 package io.airbyte.integrations.destination.snowflake.write
 
 import io.airbyte.cdk.load.test.util.NameMapper
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
 
 class SnowflakeNameMapper : NameMapper {
     override fun mapFieldName(path: List<String>): List<String> =
@@ -9,13 +9,16 @@ import io.airbyte.cdk.load.command.DestinationStream
 import io.airbyte.cdk.load.data.StringValue
 import io.airbyte.cdk.load.data.json.toAirbyteValue
 import io.airbyte.cdk.load.message.Meta
+import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator
 import io.airbyte.cdk.load.test.util.DestinationDataDumper
 import io.airbyte.cdk.load.test.util.OutputRecord
+import io.airbyte.cdk.load.util.UUIDGenerator
 import io.airbyte.cdk.load.util.deserializeToNode
 import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory
-import io.airbyte.integrations.destination.snowflake.db.SnowflakeFinalTableNameGenerator
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeTableSchemaMapper
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils
+import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
 
 class SnowflakeRawDataDumper(
     private val configProvider: (ConfigurationSpecification) -> SnowflakeConfiguration
@@ -27,8 +30,18 @@ class SnowflakeRawDataDumper(
         val output = mutableListOf<OutputRecord>()
 
         val config = configProvider(spec)
-        val sqlUtils = SnowflakeSqlNameUtils(config)
-        val snowflakeFinalTableNameGenerator = SnowflakeFinalTableNameGenerator(config)
+        val snowflakeColumnManager = SnowflakeColumnManager(config)
+        val sqlGenerator =
+            SnowflakeDirectLoadSqlGenerator(
+                UUIDGenerator(),
+                config,
+                snowflakeColumnManager,
+            )
+        val snowflakeFinalTableNameGenerator =
+            SnowflakeTableSchemaMapper(
+                config = config,
+                tempTableNameGenerator = DefaultTempTableNameGenerator(),
+            )
         val dataSource =
             SnowflakeBeanFactory()
                 .snowflakeDataSource(snowflakeConfiguration = config, airbyteEdition = "COMMUNITY")
@@ -37,11 +50,11 @@ class SnowflakeRawDataDumper(
        ds.connection.use { connection ->
            val statement = connection.createStatement()
            val tableName =
-                snowflakeFinalTableNameGenerator.getTableName(stream.mappedDescriptor)
+                snowflakeFinalTableNameGenerator.toFinalTableName(stream.mappedDescriptor)
 
            val resultSet =
                statement.executeQuery(
-                    "SELECT * FROM ${sqlUtils.fullyQualifiedName(tableName)}"
+                    "SELECT * FROM ${sqlGenerator.fullyQualifiedName(tableName)}"
                )
 
            while (resultSet.next()) {
@@ -6,7 +6,7 @@ package io.airbyte.integrations.destination.snowflake
 
 import com.zaxxer.hikari.HikariConfig
 import com.zaxxer.hikari.HikariDataSource
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
 import io.airbyte.integrations.destination.snowflake.spec.CdcDeletionMode
 import io.airbyte.integrations.destination.snowflake.spec.KeyPairAuthConfiguration
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
@@ -5,11 +5,9 @@
 package io.airbyte.integrations.destination.snowflake.check
 
 import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.airbyte.integrations.destination.snowflake.sql.DEFAULT_COLUMNS
-import io.airbyte.integrations.destination.snowflake.sql.RAW_DATA_COLUMN
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
 import io.mockk.coEvery
 import io.mockk.coVerify
 import io.mockk.every
@@ -23,47 +21,31 @@ internal class SnowflakeCheckerTest {
     @ParameterizedTest
     @ValueSource(booleans = [true, false])
     fun testSuccessfulCheck(isLegacyRawTablesOnly: Boolean) {
-        val defaultColumnsMap =
-            if (isLegacyRawTablesOnly) {
-                linkedMapOf<String, String>().also { map ->
-                    (DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
-                        map[it.columnName] = it.columnType
-                    }
-                }
-            } else {
-                linkedMapOf<String, String>().also { map ->
-                    (DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
-                        map[it.columnName.toSnowflakeCompatibleName()] = it.columnType
-                    }
-                }
-            }
-        val defaultColumns = defaultColumnsMap.keys.toMutableList()
        val snowflakeAirbyteClient: SnowflakeAirbyteClient =
-            mockk(relaxed = true) {
-                coEvery { countTable(any()) } returns 1L
-                coEvery { describeTable(any()) } returns defaultColumnsMap
-            }
+            mockk(relaxed = true) { coEvery { countTable(any()) } returns 1L }
 
        val testSchema = "test-schema"
        val snowflakeConfiguration: SnowflakeConfiguration = mockk {
            every { schema } returns testSchema
            every { legacyRawTablesOnly } returns isLegacyRawTablesOnly
        }
-        val snowflakeColumnUtils =
-            mockk<SnowflakeColumnUtils>(relaxUnitFun = true) {
-                every { getFormattedDefaultColumnNames(any()) } returns defaultColumns
-            }
+        val columnManager: SnowflakeColumnManager = SnowflakeColumnManager(snowflakeConfiguration)
 
        val checker =
            SnowflakeChecker(
                snowflakeAirbyteClient = snowflakeAirbyteClient,
                snowflakeConfiguration = snowflakeConfiguration,
-                snowflakeColumnUtils = snowflakeColumnUtils,
+                columnManager = columnManager,
            )
        checker.check()
 
        coVerify(exactly = 1) {
-            snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
+            if (isLegacyRawTablesOnly) {
+                snowflakeAirbyteClient.createNamespace(testSchema)
+            } else {
+                snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
+            }
        }
        coVerify(exactly = 1) { snowflakeAirbyteClient.createTable(any(), any(), any(), any()) }
        coVerify(exactly = 1) { snowflakeAirbyteClient.dropTable(any()) }
@@ -72,48 +54,32 @@ internal class SnowflakeCheckerTest {
     @ParameterizedTest
     @ValueSource(booleans = [true, false])
     fun testUnsuccessfulCheck(isLegacyRawTablesOnly: Boolean) {
-        val defaultColumnsMap =
-            if (isLegacyRawTablesOnly) {
-                linkedMapOf<String, String>().also { map ->
-                    (DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
-                        map[it.columnName] = it.columnType
-                    }
-                }
-            } else {
-                linkedMapOf<String, String>().also { map ->
-                    (DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
-                        map[it.columnName.toSnowflakeCompatibleName()] = it.columnType
-                    }
-                }
-            }
-        val defaultColumns = defaultColumnsMap.keys.toMutableList()
        val snowflakeAirbyteClient: SnowflakeAirbyteClient =
-            mockk(relaxed = true) {
-                coEvery { countTable(any()) } returns 0L
-                coEvery { describeTable(any()) } returns defaultColumnsMap
-            }
+            mockk(relaxed = true) { coEvery { countTable(any()) } returns 0L }
 
        val testSchema = "test-schema"
        val snowflakeConfiguration: SnowflakeConfiguration = mockk {
            every { schema } returns testSchema
            every { legacyRawTablesOnly } returns isLegacyRawTablesOnly
        }
-        val snowflakeColumnUtils =
-            mockk<SnowflakeColumnUtils>(relaxUnitFun = true) {
-                every { getFormattedDefaultColumnNames(any()) } returns defaultColumns
-            }
+        val columnManager: SnowflakeColumnManager = SnowflakeColumnManager(snowflakeConfiguration)
 
        val checker =
            SnowflakeChecker(
                snowflakeAirbyteClient = snowflakeAirbyteClient,
                snowflakeConfiguration = snowflakeConfiguration,
-                snowflakeColumnUtils = snowflakeColumnUtils,
+                columnManager = columnManager,
            )
 
        assertThrows<IllegalArgumentException> { checker.check() }
 
        coVerify(exactly = 1) {
-            snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
+            if (isLegacyRawTablesOnly) {
+                snowflakeAirbyteClient.createNamespace(testSchema)
+            } else {
+                snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
+            }
        }
        coVerify(exactly = 1) { snowflakeAirbyteClient.createTable(any(), any(), any(), any()) }
        coVerify(exactly = 1) { snowflakeAirbyteClient.dropTable(any()) }
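
Both checker tests now verify branch-dependent namespace handling: legacy raw-tables mode passes the configured schema through verbatim, while the default path normalizes it with `toSnowflakeCompatibleName()` first. A minimal sketch of that decision; the normalizer below is a stand-in (simple uppercasing), not the connector's real implementation:

```kotlin
// Stand-in for the connector's toSnowflakeCompatibleName(); the real one
// also escapes quotes and a few special characters.
fun String.toSnowflakeCompatibleName(): String = uppercase()

// Mirrors what the checker tests assert: legacy raw-tables mode keeps the
// configured schema as-is, the default path normalizes it first.
fun resolveNamespace(configuredSchema: String, legacyRawTablesOnly: Boolean): String =
    if (legacyRawTablesOnly) configuredSchema
    else configuredSchema.toSnowflakeCompatibleName()

fun main() {
    println(resolveNamespace("test-schema", legacyRawTablesOnly = true))  // test-schema
    println(resolveNamespace("test-schema", legacyRawTablesOnly = false)) // TEST-SCHEMA
}
```
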
@@ -6,24 +6,16 @@ package io.airbyte.integrations.destination.snowflake.client
 
 import io.airbyte.cdk.ConfigErrorException
 import io.airbyte.cdk.load.command.DestinationStream
-import io.airbyte.cdk.load.command.NamespaceMapper
-import io.airbyte.cdk.load.command.Overwrite
 import io.airbyte.cdk.load.component.ColumnType
-import io.airbyte.cdk.load.config.NamespaceDefinitionType
-import io.airbyte.cdk.load.data.AirbyteType
-import io.airbyte.cdk.load.data.FieldType
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
+import io.airbyte.cdk.load.schema.model.TableName
 import io.airbyte.cdk.load.table.ColumnNameMapping
-import io.airbyte.cdk.load.table.TableName
-import io.airbyte.integrations.destination.snowflake.db.ColumnDefinition
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
 import io.airbyte.integrations.destination.snowflake.sql.COUNT_TOTAL_ALIAS
-import io.airbyte.integrations.destination.snowflake.sql.ColumnAndType
-import io.airbyte.integrations.destination.snowflake.sql.DEFAULT_COLUMNS
 import io.airbyte.integrations.destination.snowflake.sql.QUOTE
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
 import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
 import io.mockk.Runs
 import io.mockk.every
@@ -49,31 +41,18 @@ internal class SnowflakeAirbyteClientTest {
     private lateinit var client: SnowflakeAirbyteClient
     private lateinit var dataSource: DataSource
     private lateinit var sqlGenerator: SnowflakeDirectLoadSqlGenerator
-    private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
     private lateinit var snowflakeConfiguration: SnowflakeConfiguration
+    private lateinit var columnManager: SnowflakeColumnManager
 
     @BeforeEach
     fun setup() {
        dataSource = mockk()
        sqlGenerator = mockk(relaxed = true)
-        snowflakeColumnUtils =
-            mockk(relaxed = true) {
-                every { formatColumnName(any()) } answers
-                    {
-                        firstArg<String>().toSnowflakeCompatibleName()
-                    }
-                every { getFormattedDefaultColumnNames(any()) } returns
-                    DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() }
-            }
        snowflakeConfiguration =
            mockk(relaxed = true) { every { database } returns "test_database" }
+        columnManager = mockk(relaxed = true)
        client =
-            SnowflakeAirbyteClient(
-                dataSource,
-                sqlGenerator,
-                snowflakeColumnUtils,
-                snowflakeConfiguration
-            )
+            SnowflakeAirbyteClient(dataSource, sqlGenerator, snowflakeConfiguration, columnManager)
    }
 
     @Test
@@ -231,7 +210,7 @@ internal class SnowflakeAirbyteClientTest {
     @Test
     fun testCreateTable() {
        val columnNameMapping = mockk<ColumnNameMapping>(relaxed = true)
-        val stream = mockk<DestinationStream>()
+        val stream = mockk<DestinationStream>(relaxed = true)
        val tableName = TableName(namespace = "namespace", name = "name")
        val resultSet = mockk<ResultSet>(relaxed = true)
        val statement =
@@ -254,9 +233,7 @@ internal class SnowflakeAirbyteClientTest {
            columnNameMapping = columnNameMapping,
            replace = true,
        )
-        verify(exactly = 1) {
-            sqlGenerator.createTable(stream, tableName, columnNameMapping, true)
-        }
+        verify(exactly = 1) { sqlGenerator.createTable(tableName, any(), true) }
        verify(exactly = 1) { sqlGenerator.createSnowflakeStage(tableName) }
        verify(exactly = 2) { mockConnection.close() }
    }
@@ -288,7 +265,7 @@ internal class SnowflakeAirbyteClientTest {
            targetTableName = destinationTableName,
        )
        verify(exactly = 1) {
-            sqlGenerator.copyTable(columnNameMapping, sourceTableName, destinationTableName)
+            sqlGenerator.copyTable(any<Set<String>>(), sourceTableName, destinationTableName)
        }
        verify(exactly = 1) { mockConnection.close() }
    }
@@ -299,7 +276,7 @@ internal class SnowflakeAirbyteClientTest {
        val columnNameMapping = mockk<ColumnNameMapping>(relaxed = true)
        val sourceTableName = TableName(namespace = "namespace", name = "source")
        val destinationTableName = TableName(namespace = "namespace", name = "destination")
-        val stream = mockk<DestinationStream>()
+        val stream = mockk<DestinationStream>(relaxed = true)
        val resultSet = mockk<ResultSet>(relaxed = true)
        val statement =
            mockk<Statement> {
@@ -322,12 +299,7 @@ internal class SnowflakeAirbyteClientTest {
            targetTableName = destinationTableName,
        )
        verify(exactly = 1) {
-            sqlGenerator.upsertTable(
-                stream,
-                columnNameMapping,
-                sourceTableName,
-                destinationTableName
-            )
+            sqlGenerator.upsertTable(any(), sourceTableName, destinationTableName)
        }
        verify(exactly = 1) { mockConnection.close() }
    }
@@ -379,7 +351,7 @@ internal class SnowflakeAirbyteClientTest {
        }
 
        every { dataSource.connection } returns mockConnection
-        every { snowflakeColumnUtils.getGenerationIdColumnName() } returns generationIdColumnName
+        every { columnManager.getGenerationIdColumnName() } returns generationIdColumnName
        every { sqlGenerator.getGenerationId(tableName) } returns
            "SELECT $generationIdColumnName FROM ${tableName.toPrettyString(QUOTE)}"
 
@@ -501,8 +473,8 @@ internal class SnowflakeAirbyteClientTest {
        every { dataSource.connection } returns mockConnection
 
        runBlocking {
-            client.copyFromStage(tableName, "test.csv.gz")
-            verify(exactly = 1) { sqlGenerator.copyFromStage(tableName, "test.csv.gz") }
+            client.copyFromStage(tableName, "test.csv.gz", listOf())
+            verify(exactly = 1) { sqlGenerator.copyFromStage(tableName, "test.csv.gz", listOf()) }
            verify(exactly = 1) { mockConnection.close() }
        }
    }
@@ -556,7 +528,7 @@ internal class SnowflakeAirbyteClientTest {
            "COL1" andThen
                COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName() andThen
                "COL2"
-        every { resultSet.getString("type") } returns "VARCHAR(255)" andThen "NUMBER(38,0)"
+        every { resultSet.getString("type") } returns "VARCHAR(255)" andThen "NUMBER"
        every { resultSet.getString("null?") } returns "Y" andThen "N" andThen "N"
 
        val statement =
@@ -571,6 +543,10 @@ internal class SnowflakeAirbyteClientTest {
 
        every { dataSource.connection } returns connection
 
+        // Mock the columnManager to return the correct set of meta columns
+        every { columnManager.getMetaColumnNames() } returns
+            setOf(COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName())
+
        val result = client.getColumnsFromDb(tableName)
 
        val expectedColumns =
@@ -582,81 +558,6 @@ internal class SnowflakeAirbyteClientTest {
        assertEquals(expectedColumns, result)
    }
 
-    @Test
-    fun `getColumnsFromStream should return correct column definitions`() {
-        val schema = mockk<AirbyteType>()
-        val stream =
-            DestinationStream(
-                unmappedNamespace = "test_namespace",
-                unmappedName = "test_stream",
-                importType = Overwrite,
-                schema = schema,
-                generationId = 1,
-                minimumGenerationId = 1,
-                syncId = 1,
-                namespaceMapper = NamespaceMapper(NamespaceDefinitionType.DESTINATION)
-            )
-        val columnNameMapping =
-            ColumnNameMapping(
-                mapOf(
-                    "col1" to "COL1_MAPPED",
-                    "col2" to "COL2_MAPPED",
-                )
-            )
-
-        val col1FieldType = mockk<FieldType>()
-        every { col1FieldType.type } returns mockk()
-
-        val col2FieldType = mockk<FieldType>()
-        every { col2FieldType.type } returns mockk()
-
-        every { schema.asColumns() } returns
-            linkedMapOf("col1" to col1FieldType, "col2" to col2FieldType)
-        every { snowflakeColumnUtils.toDialectType(col1FieldType.type) } returns "VARCHAR(255)"
-        every { snowflakeColumnUtils.toDialectType(col2FieldType.type) } returns "NUMBER(38,0)"
-        every { snowflakeColumnUtils.columnsAndTypes(any(), any()) } returns
-            listOf(ColumnAndType("COL1_MAPPED", "VARCHAR"), ColumnAndType("COL2_MAPPED", "NUMBER"))
-        every { snowflakeColumnUtils.formatColumnName(any(), false) } answers
-            {
-                firstArg<String>().toSnowflakeCompatibleName()
-            }
-
-        val result = client.getColumnsFromStream(stream, columnNameMapping)
-
-        val expectedColumns =
-            mapOf(
-                "COL1_MAPPED" to ColumnType("VARCHAR", true),
-                "COL2_MAPPED" to ColumnType("NUMBER", true),
-            )
-
-        assertEquals(expectedColumns, result)
-    }
-
-    @Test
-    fun `generateSchemaChanges should correctly identify changes`() {
-        val columnsInDb =
-            setOf(
-                ColumnDefinition("COL1", "VARCHAR"),
-                ColumnDefinition("COL2", "NUMBER"),
-                ColumnDefinition("COL3", "BOOLEAN")
-            )
-        val columnsInStream =
-            setOf(
-                ColumnDefinition("COL1", "VARCHAR"), // Unchanged
-                ColumnDefinition("COL3", "TEXT"), // Modified
-                ColumnDefinition("COL4", "DATE") // Added
-            )
-
-        val (added, deleted, modified) = client.generateSchemaChanges(columnsInDb, columnsInStream)
-
-        assertEquals(1, added.size)
-        assertEquals("COL4", added.first().name)
-        assertEquals(1, deleted.size)
-        assertEquals("COL2", deleted.first().name)
-        assertEquals(1, modified.size)
-        assertEquals("COL3", modified.first().name)
-    }
-
     @Test
     fun testCreateNamespaceWithNetworkFailure() {
        val namespace = "test_namespace"
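
Although the `generateSchemaChanges` test is deleted here, its assertions document the comparison semantics: columns are matched by name, with additions, deletions, and type changes reported separately. A self-contained sketch consistent with those expectations; `ColumnDefinition` is a stand-in for the removed connector class:

```kotlin
// Stand-in for the connector's ColumnDefinition: a column name and its SQL type.
data class ColumnDefinition(val name: String, val type: String)

// Diff two column sets by name: additions, deletions, and in-place type changes.
fun generateSchemaChanges(
    columnsInDb: Set<ColumnDefinition>,
    columnsInStream: Set<ColumnDefinition>,
): Triple<List<ColumnDefinition>, List<ColumnDefinition>, List<ColumnDefinition>> {
    val dbByName = columnsInDb.associateBy { it.name }
    val streamByName = columnsInStream.associateBy { it.name }
    val added = columnsInStream.filter { it.name !in dbByName }
    val deleted = columnsInDb.filter { it.name !in streamByName }
    val modified = columnsInStream.filter { col ->
        val existing = dbByName[col.name]
        existing != null && existing.type != col.type
    }
    return Triple(added, deleted, modified)
}

fun main() {
    val db = setOf(
        ColumnDefinition("COL1", "VARCHAR"),
        ColumnDefinition("COL2", "NUMBER"),
        ColumnDefinition("COL3", "BOOLEAN"),
    )
    val stream = setOf(
        ColumnDefinition("COL1", "VARCHAR"), // unchanged
        ColumnDefinition("COL3", "TEXT"),    // modified
        ColumnDefinition("COL4", "DATE"),    // added
    )
    val (added, deleted, modified) = generateSchemaChanges(db, stream)
    println(added.map { it.name })    // [COL4]
    println(deleted.map { it.name })  // [COL2]
    println(modified.map { it.name }) // [COL3]
}
```
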
@@ -4,14 +4,19 @@
 
 package io.airbyte.integrations.destination.snowflake.dataflow
 
+import io.airbyte.cdk.load.command.DestinationCatalog
+import io.airbyte.cdk.load.command.DestinationStream
 import io.airbyte.cdk.load.command.DestinationStream.Descriptor
 import io.airbyte.cdk.load.dataflow.aggregate.StoreKey
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig
-import io.airbyte.cdk.load.table.TableName
+import io.airbyte.cdk.load.schema.model.TableName
+import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
 import io.airbyte.cdk.load.write.StreamStateStore
 import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
+import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRawRecordFormatter
+import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
+import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter
 import io.mockk.every
 import io.mockk.mockk
 import org.junit.jupiter.api.Assertions.assertEquals
@@ -22,27 +27,33 @@ internal class SnowflakeAggregateFactoryTest {
     @Test
     fun testCreatingAggregateWithRawBuffer() {
        val descriptor = Descriptor(namespace = "namespace", name = "name")
-        val directLoadTableExecutionConfig =
-            DirectLoadTableExecutionConfig(
-                tableName =
-                    TableName(
-                        namespace = descriptor.namespace!!,
-                        name = descriptor.name,
-                    )
+        val tableName =
+            TableName(
+                namespace = descriptor.namespace!!,
+                name = descriptor.name,
            )
+        val directLoadTableExecutionConfig = DirectLoadTableExecutionConfig(tableName = tableName)
        val key = StoreKey(namespace = descriptor.namespace!!, name = descriptor.name)
        val streamStore = StreamStateStore<DirectLoadTableExecutionConfig>()
-        streamStore.put(descriptor, directLoadTableExecutionConfig)
+        streamStore.put(key, directLoadTableExecutionConfig)
 
+        val stream = mockk<DestinationStream>(relaxed = true)
+        val catalog = mockk<DestinationCatalog> { every { getStream(key) } returns stream }
+
        val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
        val snowflakeConfiguration =
            mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns true }
-        val snowflakeColumnUtils = mockk<SnowflakeColumnUtils>(relaxed = true)
+        val columnManager = SnowflakeColumnManager(snowflakeConfiguration)
+        val snowflakeRecordFormatter: SnowflakeRecordFormatter = SnowflakeRawRecordFormatter()
 
        val factory =
            SnowflakeAggregateFactory(
                snowflakeClient = snowflakeClient,
                streamStateStore = streamStore,
                snowflakeConfiguration = snowflakeConfiguration,
-                snowflakeColumnUtils = snowflakeColumnUtils,
+                catalog = catalog,
+                columnManager = columnManager,
+                snowflakeRecordFormatter = snowflakeRecordFormatter,
            )
        val aggregate = factory.create(key)
        assertNotNull(aggregate)
@@ -52,26 +63,33 @@ internal class SnowflakeAggregateFactoryTest {
     @Test
     fun testCreatingAggregateWithStagingBuffer() {
        val descriptor = Descriptor(namespace = "namespace", name = "name")
-        val directLoadTableExecutionConfig =
-            DirectLoadTableExecutionConfig(
-                tableName =
-                    TableName(
-                        namespace = descriptor.namespace!!,
-                        name = descriptor.name,
-                    )
+        val tableName =
+            TableName(
+                namespace = descriptor.namespace!!,
+                name = descriptor.name,
            )
+        val directLoadTableExecutionConfig = DirectLoadTableExecutionConfig(tableName = tableName)
        val key = StoreKey(namespace = descriptor.namespace!!, name = descriptor.name)
        val streamStore = StreamStateStore<DirectLoadTableExecutionConfig>()
-        streamStore.put(descriptor, directLoadTableExecutionConfig)
+        streamStore.put(key, directLoadTableExecutionConfig)
 
+        val stream = mockk<DestinationStream>(relaxed = true)
+        val catalog = mockk<DestinationCatalog> { every { getStream(key) } returns stream }
+
        val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
-        val snowflakeConfiguration = mockk<SnowflakeConfiguration>(relaxed = true)
-        val snowflakeColumnUtils = mockk<SnowflakeColumnUtils>(relaxed = true)
+        val snowflakeConfiguration =
+            mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns false }
+        val columnManager = SnowflakeColumnManager(snowflakeConfiguration)
+        val snowflakeRecordFormatter: SnowflakeRecordFormatter = SnowflakeSchemaRecordFormatter()
 
        val factory =
            SnowflakeAggregateFactory(
                snowflakeClient = snowflakeClient,
                streamStateStore = streamStore,
                snowflakeConfiguration = snowflakeConfiguration,
-                snowflakeColumnUtils = snowflakeColumnUtils,
+                catalog = catalog,
+                columnManager = columnManager,
+                snowflakeRecordFormatter = snowflakeRecordFormatter,
            )
        val aggregate = factory.create(key)
        assertNotNull(aggregate)
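
The two factory tests now choose the record formatter explicitly rather than mocking a column utility: `SnowflakeRawRecordFormatter` when `legacyRawTablesOnly` is true, `SnowflakeSchemaRecordFormatter` otherwise. A sketch of that selection with stand-in formatter types (the real ones live under `write.load` and carry actual formatting logic):

```kotlin
// Stand-ins for the connector's SnowflakeRecordFormatter hierarchy.
interface SnowflakeRecordFormatter
class SnowflakeRawRecordFormatter : SnowflakeRecordFormatter
class SnowflakeSchemaRecordFormatter : SnowflakeRecordFormatter

// Mirrors the choice the two tests exercise: legacy raw tables get the raw
// formatter, direct-load tables get the schema-aware one.
fun selectFormatter(legacyRawTablesOnly: Boolean): SnowflakeRecordFormatter =
    if (legacyRawTablesOnly) SnowflakeRawRecordFormatter() else SnowflakeSchemaRecordFormatter()

fun main() {
    println(selectFormatter(true)::class.simpleName)  // SnowflakeRawRecordFormatter
    println(selectFormatter(false)::class.simpleName) // SnowflakeSchemaRecordFormatter
}
```
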
@@ -1,23 +0,0 @@
-/*
- * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.snowflake.db
-
-import io.mockk.every
-import io.mockk.mockk
-import org.junit.jupiter.api.Assertions.assertEquals
-import org.junit.jupiter.api.Test
-
-internal class SnowflakeColumnNameGeneratorTest {
-
-    @Test
-    fun testGetColumnName() {
-        val column = "test-column"
-        val generator =
-            SnowflakeColumnNameGenerator(mockk { every { legacyRawTablesOnly } returns false })
-        val columnName = generator.getColumnName(column)
-        assertEquals(column.toSnowflakeCompatibleName(), columnName.displayName)
-        assertEquals(column.toSnowflakeCompatibleName(), columnName.canonicalName)
-    }
-}
@@ -1,72 +0,0 @@
-/*
- * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.snowflake.db
-
-import io.airbyte.cdk.load.command.DestinationStream
-import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.mockk.every
-import io.mockk.mockk
-import org.junit.jupiter.api.Assertions.assertEquals
-import org.junit.jupiter.api.Test
-
-internal class SnowflakeFinalTableNameGeneratorTest {
-
-    @Test
-    fun testGetTableNameWithInternalNamespace() {
-        val configuration =
-            mockk<SnowflakeConfiguration> {
-                every { internalTableSchema } returns "test-internal-namespace"
-                every { legacyRawTablesOnly } returns true
-            }
-        val generator = SnowflakeFinalTableNameGenerator(config = configuration)
-        val streamName = "test-stream-name"
-        val streamNamespace = "test-stream-namespace"
-        val streamDescriptor =
-            mockk<DestinationStream.Descriptor> {
-                every { namespace } returns streamNamespace
-                every { name } returns streamName
-            }
-        val tableName = generator.getTableName(streamDescriptor)
-        assertEquals("test-stream-namespace_raw__stream_test-stream-name", tableName.name)
-        assertEquals("test-internal-namespace", tableName.namespace)
-    }
-
-    @Test
-    fun testGetTableNameWithNamespace() {
-        val configuration =
-            mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns false }
-        val generator = SnowflakeFinalTableNameGenerator(config = configuration)
-        val streamName = "test-stream-name"
-        val streamNamespace = "test-stream-namespace"
-        val streamDescriptor =
-            mockk<DestinationStream.Descriptor> {
-                every { namespace } returns streamNamespace
-                every { name } returns streamName
-            }
-        val tableName = generator.getTableName(streamDescriptor)
-        assertEquals("TEST-STREAM-NAME", tableName.name)
-        assertEquals("TEST-STREAM-NAMESPACE", tableName.namespace)
-    }
-
-    @Test
-    fun testGetTableNameWithDefaultNamespace() {
-        val defaultNamespace = "test-default-namespace"
-        val configuration =
-            mockk<SnowflakeConfiguration> {
-                every { schema } returns defaultNamespace
-                every { legacyRawTablesOnly } returns false
-            }
-        val generator = SnowflakeFinalTableNameGenerator(config = configuration)
-        val streamName = "test-stream-name"
-        val streamDescriptor =
-            mockk<DestinationStream.Descriptor> {
-                every { namespace } returns null
-                every { name } returns streamName
-            }
-        val tableName = generator.getTableName(streamDescriptor)
-        assertEquals("TEST-STREAM-NAME", tableName.name)
-        assertEquals("TEST-DEFAULT-NAMESPACE", tableName.namespace)
-    }
-}
@@ -1,27 +0,0 @@
-/*
- * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.snowflake.db
-
-import org.junit.jupiter.api.Assertions.assertEquals
-import org.junit.jupiter.params.ParameterizedTest
-import org.junit.jupiter.params.provider.CsvSource
-
-internal class SnowflakeNameGeneratorsTest {
-
-    @ParameterizedTest
-    @CsvSource(
-        value =
-            [
-                "test-name,TEST-NAME",
-                "1-test-name,1-TEST-NAME",
-                "test-name!!!,TEST-NAME!!!",
-                "test\${name,TEST__NAME",
-                "test\"name,TEST\"\"NAME",
-            ]
-    )
-    fun testToSnowflakeCompatibleName(name: String, expected: String) {
-        assertEquals(expected, name.toSnowflakeCompatibleName())
-    }
-}
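
The deleted parameterized test is the one place in this diff that spells out what `toSnowflakeCompatibleName()` does to an identifier: uppercase it, double any embedded double quotes, and map `$` and `{` to underscores, while hyphens, leading digits, and `!` pass through. A stand-in implementation consistent with those five vectors; the real function now lives in the `schema` package and may cover more cases:

```kotlin
// A reconstruction from the deleted test's vectors, not the connector's actual code:
// uppercase, escape double quotes by doubling them, and replace '$' and '{' with '_'.
fun String.toSnowflakeCompatibleName(): String =
    uppercase()
        .replace("\"", "\"\"")
        .replace('$', '_')
        .replace('{', '_')

fun main() {
    check("test-name".toSnowflakeCompatibleName() == "TEST-NAME")
    check("1-test-name".toSnowflakeCompatibleName() == "1-TEST-NAME")
    check("test-name!!!".toSnowflakeCompatibleName() == "TEST-NAME!!!")
    check("test\${name".toSnowflakeCompatibleName() == "TEST__NAME")
    check("test\"name".toSnowflakeCompatibleName() == "TEST\"\"NAME")
    println("all vectors pass")
}
```
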
@@ -1,358 +0,0 @@
-/*
- * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.snowflake.sql
-
-import com.fasterxml.jackson.databind.JsonNode
-import io.airbyte.cdk.load.data.ArrayType
-import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
-import io.airbyte.cdk.load.data.BooleanType
-import io.airbyte.cdk.load.data.DateType
-import io.airbyte.cdk.load.data.FieldType
-import io.airbyte.cdk.load.data.IntegerType
-import io.airbyte.cdk.load.data.NumberType
-import io.airbyte.cdk.load.data.ObjectType
-import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
-import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
-import io.airbyte.cdk.load.data.StringType
-import io.airbyte.cdk.load.data.TimeTypeWithTimezone
-import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
-import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
-import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
-import io.airbyte.cdk.load.data.UnionType
-import io.airbyte.cdk.load.data.UnknownType
-import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA
-import io.airbyte.cdk.load.orchestration.db.ColumnNameGenerator
-import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN
-import io.airbyte.cdk.load.table.ColumnNameMapping
-import io.airbyte.integrations.destination.snowflake.db.SnowflakeColumnNameGenerator
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
-import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.mockk.every
-import io.mockk.mockk
-import kotlin.collections.LinkedHashMap
-import org.junit.jupiter.api.Assertions.assertEquals
-import org.junit.jupiter.api.BeforeEach
-import org.junit.jupiter.api.Test
-import org.junit.jupiter.params.ParameterizedTest
-import org.junit.jupiter.params.provider.CsvSource
-
-internal class SnowflakeColumnUtilsTest {
-
-    private lateinit var snowflakeConfiguration: SnowflakeConfiguration
-    private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
-    private lateinit var snowflakeColumnNameGenerator: SnowflakeColumnNameGenerator
-
-    @BeforeEach
-    fun setup() {
-        snowflakeConfiguration = mockk(relaxed = true)
-        snowflakeColumnNameGenerator =
-            mockk(relaxed = true) {
-                every { getColumnName(any()) } answers
-                    {
-                        val displayName =
-                            if (snowflakeConfiguration.legacyRawTablesOnly) firstArg<String>()
-                            else firstArg<String>().toSnowflakeCompatibleName()
-                        val canonicalName =
-                            if (snowflakeConfiguration.legacyRawTablesOnly) firstArg<String>()
-                            else firstArg<String>().toSnowflakeCompatibleName()
-                        ColumnNameGenerator.ColumnName(
-                            displayName = displayName,
-                            canonicalName = canonicalName,
-                        )
-                    }
-            }
-        snowflakeColumnUtils =
-            SnowflakeColumnUtils(snowflakeConfiguration, snowflakeColumnNameGenerator)
-    }
-
-    @Test
-    fun testDefaultColumns() {
-        val expectedDefaultColumns = DEFAULT_COLUMNS
-        assertEquals(expectedDefaultColumns, snowflakeColumnUtils.defaultColumns())
-    }
-
-    @Test
-    fun testDefaultRawColumns() {
-        every { snowflakeConfiguration.legacyRawTablesOnly } returns true
-
-        val expectedDefaultColumns = DEFAULT_COLUMNS + RAW_COLUMNS
-
-        assertEquals(expectedDefaultColumns, snowflakeColumnUtils.defaultColumns())
-    }
-
-    @Test
-    fun testGetFormattedDefaultColumnNames() {
-        val expectedDefaultColumnNames =
-            DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() }
-        val defaultColumnNames = snowflakeColumnUtils.getFormattedDefaultColumnNames()
-        assertEquals(expectedDefaultColumnNames, defaultColumnNames)
-    }
-
-    @Test
-    fun testGetFormattedDefaultColumnNamesQuoted() {
-        val expectedDefaultColumnNames =
-            DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName().quote() }
-        val defaultColumnNames = snowflakeColumnUtils.getFormattedDefaultColumnNames(true)
-        assertEquals(expectedDefaultColumnNames, defaultColumnNames)
-    }
-
-    @Test
-    fun testGetColumnName() {
-        val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
-        val columnNames = snowflakeColumnUtils.getColumnNames(columnNameMapping)
-        val expectedColumnNames =
-            (DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() } + listOf("actual"))
-                .joinToString(",") { it.quote() }
-        assertEquals(expectedColumnNames, columnNames)
-    }
-
-    @Test
-    fun testGetRawColumnName() {
-        every { snowflakeConfiguration.legacyRawTablesOnly } returns true
-
-        val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
-        val columnNames = snowflakeColumnUtils.getColumnNames(columnNameMapping)
-        val expectedColumnNames =
-            (DEFAULT_COLUMNS.map { it.columnName } + RAW_COLUMNS.map { it.columnName })
-                .joinToString(",") { it.quote() }
-        assertEquals(expectedColumnNames, columnNames)
-    }
-
-    @Test
-    fun testGetRawFormattedColumnNames() {
-        every { snowflakeConfiguration.legacyRawTablesOnly } returns true
-        val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
-        val schemaColumns =
-            mapOf(
-                "column_one" to FieldType(StringType, true),
-                "column_two" to FieldType(IntegerType, true),
-                "original" to FieldType(StringType, true),
-                CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true)
-            )
-        val expectedColumnNames =
-            DEFAULT_COLUMNS.map { it.columnName.quote() } +
-                RAW_COLUMNS.map { it.columnName.quote() }
-
-        val columnNames =
-            snowflakeColumnUtils.getFormattedColumnNames(
-                columns = schemaColumns,
-                columnNameMapping = columnNameMapping
-            )
-        assertEquals(expectedColumnNames.size, columnNames.size)
-        assertEquals(expectedColumnNames.sorted(), columnNames.sorted())
-    }
-
-    @Test
-    fun testGetFormattedColumnNames() {
-        val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
-        val schemaColumns =
-            mapOf(
-                "column_one" to FieldType(StringType, true),
-                "column_two" to FieldType(IntegerType, true),
-                "original" to FieldType(StringType, true),
-                CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true)
-            )
-        val expectedColumnNames =
-            listOf(
-                "actual",
-                "column_one",
-                "column_two",
-                CDC_DELETED_AT_COLUMN,
-            )
-                .map { it.quote() } +
-                DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName().quote() }
-        val columnNames =
-            snowflakeColumnUtils.getFormattedColumnNames(
-                columns = schemaColumns,
-                columnNameMapping = columnNameMapping
-            )
-        assertEquals(expectedColumnNames.size, columnNames.size)
-        assertEquals(expectedColumnNames.sorted(), columnNames.sorted())
-    }
-
-    @Test
-    fun testGetFormattedColumnNamesNoQuotes() {
-        val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
-        val schemaColumns =
-            mapOf(
-                "column_one" to FieldType(StringType, true),
-                "column_two" to FieldType(IntegerType, true),
-                "original" to FieldType(StringType, true),
-                CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true)
-            )
-        val expectedColumnNames =
-            listOf(
-                "actual",
-                "column_one",
-                "column_two",
-                CDC_DELETED_AT_COLUMN,
-            ) + DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() }
-        val columnNames =
-            snowflakeColumnUtils.getFormattedColumnNames(
-                columns = schemaColumns,
-                columnNameMapping = columnNameMapping,
-                quote = false
-            )
-        assertEquals(expectedColumnNames.size, columnNames.size)
-        assertEquals(expectedColumnNames.sorted(), columnNames.sorted())
-    }
-
-    @Test
-    fun testGeneratingRawTableColumnsAndTypesNoColumnMapping() {
-        every { snowflakeConfiguration.legacyRawTablesOnly } returns true
-
-        val columns =
-            snowflakeColumnUtils.columnsAndTypes(
-                columns = emptyMap(),
-                columnNameMapping = ColumnNameMapping(emptyMap())
-            )
-        assertEquals(DEFAULT_COLUMNS.size + RAW_COLUMNS.size, columns.size)
-        assertEquals(
-            "${SnowflakeDataType.VARIANT.typeName} $NOT_NULL",
-            columns.find { it.columnName == RAW_DATA_COLUMN.columnName }?.columnType
-        )
-    }
-
-    @Test
-    fun testGeneratingColumnsAndTypesNoColumnMapping() {
-        val columnName = "test-column"
-        val fieldType = FieldType(StringType, false)
-        val declaredColumns = mapOf(columnName to fieldType)
-
-        val columns =
-            snowflakeColumnUtils.columnsAndTypes(
-                columns = declaredColumns,
-                columnNameMapping = ColumnNameMapping(emptyMap())
-            )
-        assertEquals(DEFAULT_COLUMNS.size + 1, columns.size)
-        assertEquals(
-            "${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL",
-            columns.find { it.columnName == columnName }?.columnType
-        )
-    }
-
-    @Test
-    fun testGeneratingColumnsAndTypesWithColumnMapping() {
-        val columnName = "test-column"
-        val mappedColumnName = "mapped-column-name"
-        val fieldType = FieldType(StringType, false)
-        val declaredColumns = mapOf(columnName to fieldType)
-        val columnNameMapping = ColumnNameMapping(mapOf(columnName to mappedColumnName))
-
-        val columns =
-            snowflakeColumnUtils.columnsAndTypes(
-                columns = declaredColumns,
-                columnNameMapping = columnNameMapping
-            )
-        assertEquals(DEFAULT_COLUMNS.size + 1, columns.size)
-        assertEquals(
-            "${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL",
-            columns.find { it.columnName == mappedColumnName }?.columnType
-        )
-    }
-
-    @Test
-    fun testToDialectType() {
-        assertEquals(
-            SnowflakeDataType.BOOLEAN.typeName,
-            snowflakeColumnUtils.toDialectType(BooleanType)
-        )
-        assertEquals(SnowflakeDataType.DATE.typeName, snowflakeColumnUtils.toDialectType(DateType))
-        assertEquals(
-            SnowflakeDataType.NUMBER.typeName,
-            snowflakeColumnUtils.toDialectType(IntegerType)
-        )
-        assertEquals(
-            SnowflakeDataType.FLOAT.typeName,
-            snowflakeColumnUtils.toDialectType(NumberType)
-        )
-        assertEquals(
-            SnowflakeDataType.VARCHAR.typeName,
-            snowflakeColumnUtils.toDialectType(StringType)
-        )
-        assertEquals(
-            SnowflakeDataType.VARCHAR.typeName,
-            snowflakeColumnUtils.toDialectType(TimeTypeWithTimezone)
-        )
-        assertEquals(
-            SnowflakeDataType.TIME.typeName,
-            snowflakeColumnUtils.toDialectType(TimeTypeWithoutTimezone)
-        )
-        assertEquals(
-            SnowflakeDataType.TIMESTAMP_TZ.typeName,
-            snowflakeColumnUtils.toDialectType(TimestampTypeWithTimezone)
-        )
-        assertEquals(
-            SnowflakeDataType.TIMESTAMP_NTZ.typeName,
-            snowflakeColumnUtils.toDialectType(TimestampTypeWithoutTimezone)
-        )
-        assertEquals(
-            SnowflakeDataType.ARRAY.typeName,
-            snowflakeColumnUtils.toDialectType(ArrayType(items = FieldType(StringType, false)))
-        )
-        assertEquals(
-            SnowflakeDataType.ARRAY.typeName,
-            snowflakeColumnUtils.toDialectType(ArrayTypeWithoutSchema)
-        )
-        assertEquals(
-            SnowflakeDataType.OBJECT.typeName,
-            snowflakeColumnUtils.toDialectType(
-                ObjectType(
-                    properties = LinkedHashMap(),
-                    additionalProperties = false,
-                )
-            )
-        )
-        assertEquals(
-            SnowflakeDataType.OBJECT.typeName,
-            snowflakeColumnUtils.toDialectType(ObjectTypeWithEmptySchema)
-        )
-        assertEquals(
-            SnowflakeDataType.OBJECT.typeName,
-            snowflakeColumnUtils.toDialectType(ObjectTypeWithoutSchema)
-        )
-        assertEquals(
-            SnowflakeDataType.VARIANT.typeName,
-            snowflakeColumnUtils.toDialectType(
-                UnionType(
-                    options = setOf(StringType),
-                    isLegacyUnion = true,
-                )
-            )
-        )
-        assertEquals(
-            SnowflakeDataType.VARIANT.typeName,
-            snowflakeColumnUtils.toDialectType(
-                UnionType(
-                    options = emptySet(),
-                    isLegacyUnion = false,
-                )
-            )
-        )
-        assertEquals(
-            SnowflakeDataType.VARIANT.typeName,
-            snowflakeColumnUtils.toDialectType(UnknownType(schema = mockk<JsonNode>()))
-        )
-    }
-
-    @ParameterizedTest
-    @CsvSource(
-        value =
-            [
-                "$COLUMN_NAME_DATA, true, \"$COLUMN_NAME_DATA\"",
-                "some-other_Column, true, \"SOME-OTHER_COLUMN\"",
-                "$COLUMN_NAME_DATA, false, $COLUMN_NAME_DATA",
-                "some-other_Column, false, SOME-OTHER_COLUMN",
-                "$COLUMN_NAME_DATA, true, \"$COLUMN_NAME_DATA\"",
-                "some-other_Column, true, \"SOME-OTHER_COLUMN\"",
-            ]
-    )
-    fun testFormatColumnName(columnName: String, quote: Boolean, expectedFormattedName: String) {
-        assertEquals(
-            expectedFormattedName,
-            snowflakeColumnUtils.formatColumnName(columnName, quote)
-        )
-    }
-}
File diff suppressed because it is too large
@@ -1,97 +0,0 @@
-/*
- * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.snowflake.sql
-
-import io.airbyte.cdk.load.table.TableName
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
-import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.mockk.every
-import io.mockk.mockk
-import org.junit.jupiter.api.Assertions.assertEquals
-import org.junit.jupiter.api.BeforeEach
-import org.junit.jupiter.api.Test
-
-internal class SnowflakeSqlNameUtilsTest {
-
-    private lateinit var snowflakeConfiguration: SnowflakeConfiguration
-    private lateinit var snowflakeSqlNameUtils: SnowflakeSqlNameUtils
-
-    @BeforeEach
-    fun setUp() {
-        snowflakeConfiguration = mockk(relaxed = true)
-        snowflakeSqlNameUtils =
-            SnowflakeSqlNameUtils(snowflakeConfiguration = snowflakeConfiguration)
-    }
-
-    @Test
-    fun testFullyQualifiedName() {
-        val databaseName = "test-database"
-        val namespace = "test-namespace"
-        val name = "test=name"
-        val tableName = TableName(namespace = namespace, name = name)
-        every { snowflakeConfiguration.database } returns databaseName
-
-        val expectedName =
-            snowflakeSqlNameUtils.combineParts(
-                listOf(
-                    databaseName.toSnowflakeCompatibleName(),
-                    tableName.namespace,
-                    tableName.name
-                )
-            )
-        val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedName(tableName)
-        assertEquals(expectedName, fullyQualifiedName)
-    }
-
-    @Test
-    fun testFullyQualifiedNamespace() {
-        val databaseName = "test-database"
-        val namespace = "test-namespace"
-        every { snowflakeConfiguration.database } returns databaseName
-
-        val fullyQualifiedNamespace = snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)
-        assertEquals("\"TEST-DATABASE\".\"test-namespace\"", fullyQualifiedNamespace)
-    }
-
-    @Test
-    fun testFullyQualifiedStageName() {
-        val databaseName = "test-database"
-        val namespace = "test-namespace"
-        val name = "test=name"
-        val tableName = TableName(namespace = namespace, name = name)
-        every { snowflakeConfiguration.database } returns databaseName
-
-        val expectedName =
-            snowflakeSqlNameUtils.combineParts(
-                listOf(
-                    databaseName.toSnowflakeCompatibleName(),
-                    namespace,
-                    "$STAGE_NAME_PREFIX$name"
-                )
-            )
-        val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName)
-        assertEquals(expectedName, fullyQualifiedName)
-    }
-
-    @Test
-    fun testFullyQualifiedStageNameWithEscape() {
-        val databaseName = "test-database"
-        val namespace = "test-namespace"
-        val name = "test=\"\"\'name"
-        val tableName = TableName(namespace = namespace, name = name)
-        every { snowflakeConfiguration.database } returns databaseName
-
-        val expectedName =
-            snowflakeSqlNameUtils.combineParts(
-                listOf(
-                    databaseName.toSnowflakeCompatibleName(),
-                    namespace,
-                    "$STAGE_NAME_PREFIX${sqlEscape(name)}"
-                )
-            )
-        val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true)
-        assertEquals(expectedName, fullyQualifiedName)
-    }
-}
@@ -6,20 +6,21 @@ package io.airbyte.integrations.destination.snowflake.write
 
 import io.airbyte.cdk.SystemErrorException
 import io.airbyte.cdk.load.command.Append
+import io.airbyte.cdk.load.command.DestinationCatalog
 import io.airbyte.cdk.load.command.DestinationStream
-import io.airbyte.cdk.load.orchestration.db.DatabaseInitialStatusGatherer
-import io.airbyte.cdk.load.orchestration.db.TableNames
-import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadInitialStatus
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendStreamLoader
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendTruncateStreamLoader
-import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableStatus
-import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
-import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableNameInfo
-import io.airbyte.cdk.load.table.ColumnNameMapping
-import io.airbyte.cdk.load.table.TableName
+import io.airbyte.cdk.load.schema.model.ColumnSchema
+import io.airbyte.cdk.load.schema.model.StreamTableSchema
+import io.airbyte.cdk.load.schema.model.TableName
+import io.airbyte.cdk.load.schema.model.TableNames
+import io.airbyte.cdk.load.table.DatabaseInitialStatusGatherer
+import io.airbyte.cdk.load.table.TempTableNameGenerator
+import io.airbyte.cdk.load.table.directload.DirectLoadInitialStatus
+import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendStreamLoader
+import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendTruncateStreamLoader
+import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
+import io.airbyte.cdk.load.table.directload.DirectLoadTableStatus
+import io.airbyte.cdk.load.write.StreamStateStore
 import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
 import io.mockk.coEvery
 import io.mockk.coVerify
 import io.mockk.every
@@ -34,55 +35,93 @@ internal class SnowflakeWriterTest {
     @Test
     fun testSetup() {
         val tableName = TableName(namespace = "test-namespace", name = "test-name")
-        val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
-        val stream = mockk<DestinationStream>()
-        val tableInfo =
-            TableNameInfo(
-                tableNames = tableNames,
-                columnNameMapping = ColumnNameMapping(emptyMap())
-            )
-        val catalog = TableCatalog(mapOf(stream to tableInfo))
+        val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
+        val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
+        val stream =
+            mockk<DestinationStream> {
+                every { tableSchema } returns
+                    StreamTableSchema(
+                        tableNames = tableNames,
+                        columnSchema =
+                            ColumnSchema(
+                                inputToFinalColumnNames = emptyMap(),
+                                finalSchema = emptyMap(),
+                                inputSchema = emptyMap()
+                            ),
+                        importType = Append
+                    )
+                every { mappedDescriptor } returns
+                    DestinationStream.Descriptor(
+                        namespace = tableName.namespace,
+                        name = tableName.name
+                    )
+                every { importType } returns Append
+            }
+        val catalog = DestinationCatalog(listOf(stream))
         val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
         val stateGatherer =
            mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
-                coEvery { gatherInitialStatus(catalog) } returns emptyMap()
+                coEvery { gatherInitialStatus() } returns
+                    mapOf(
+                        stream to
+                            DirectLoadInitialStatus(
+                                realTable = DirectLoadTableStatus(false),
+                                tempTable = null
+                            )
+                    )
            }
+        val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
         val writer =
             SnowflakeWriter(
-                names = catalog,
+                catalog = catalog,
                 stateGatherer = stateGatherer,
-                streamStateStore = mockk(),
+                streamStateStore = streamStateStore,
                 snowflakeClient = snowflakeClient,
                 tempTableNameGenerator = mockk(),
-                snowflakeConfiguration = mockk(relaxed = true),
+                snowflakeConfiguration =
+                    mockk(relaxed = true) {
+                        every { internalTableSchema } returns "internal_schema"
+                    },
             )
 
         runBlocking { writer.setup() }
 
         coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) }
-        coVerify(exactly = 1) { stateGatherer.gatherInitialStatus(catalog) }
+        coVerify(exactly = 1) { stateGatherer.gatherInitialStatus() }
     }
 
     @Test
     fun testCreateStreamLoaderFirstGeneration() {
         val tableName = TableName(namespace = "test-namespace", name = "test-name")
-        val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
+        val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
+        val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
         val stream =
             mockk<DestinationStream> {
                 every { minimumGenerationId } returns 0L
                 every { generationId } returns 0L
                 every { importType } returns Append
+                every { tableSchema } returns
+                    StreamTableSchema(
+                        tableNames = tableNames,
+                        columnSchema =
+                            ColumnSchema(
+                                inputToFinalColumnNames = emptyMap(),
+                                finalSchema = emptyMap(),
+                                inputSchema = emptyMap()
+                            ),
+                        importType = Append
+                    )
+                every { mappedDescriptor } returns
+                    DestinationStream.Descriptor(
+                        namespace = tableName.namespace,
+                        name = tableName.name
+                    )
             }
-        val tableInfo =
-            TableNameInfo(
-                tableNames = tableNames,
-                columnNameMapping = ColumnNameMapping(emptyMap())
-            )
-        val catalog = TableCatalog(mapOf(stream to tableInfo))
+        val catalog = DestinationCatalog(listOf(stream))
         val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
         val stateGatherer =
             mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
-                coEvery { gatherInitialStatus(catalog) } returns
+                coEvery { gatherInitialStatus() } returns
                     mapOf(
                         stream to
                             DirectLoadInitialStatus(
@@ -93,14 +132,18 @@ internal class SnowflakeWriterTest {
             }
         val tempTableNameGenerator =
             mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } }
+        val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
         val writer =
             SnowflakeWriter(
-                names = catalog,
+                catalog = catalog,
                 stateGatherer = stateGatherer,
-                streamStateStore = mockk(),
+                streamStateStore = streamStateStore,
                 snowflakeClient = snowflakeClient,
                 tempTableNameGenerator = tempTableNameGenerator,
-                snowflakeConfiguration = mockk(relaxed = true),
+                snowflakeConfiguration =
+                    mockk(relaxed = true) {
+                        every { internalTableSchema } returns "internal_schema"
+                    },
             )
 
         runBlocking {
@@ -113,23 +156,35 @@ internal class SnowflakeWriterTest {
     @Test
     fun testCreateStreamLoaderNotFirstGeneration() {
         val tableName = TableName(namespace = "test-namespace", name = "test-name")
-        val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
+        val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
+        val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
         val stream =
             mockk<DestinationStream> {
                 every { minimumGenerationId } returns 1L
                 every { generationId } returns 1L
                 every { importType } returns Append
+                every { tableSchema } returns
+                    StreamTableSchema(
+                        tableNames = tableNames,
+                        columnSchema =
+                            ColumnSchema(
+                                inputToFinalColumnNames = emptyMap(),
+                                finalSchema = emptyMap(),
+                                inputSchema = emptyMap()
+                            ),
+                        importType = Append
+                    )
+                every { mappedDescriptor } returns
+                    DestinationStream.Descriptor(
+                        namespace = tableName.namespace,
+                        name = tableName.name
+                    )
             }
-        val tableInfo =
-            TableNameInfo(
-                tableNames = tableNames,
-                columnNameMapping = ColumnNameMapping(emptyMap())
-            )
-        val catalog = TableCatalog(mapOf(stream to tableInfo))
+        val catalog = DestinationCatalog(listOf(stream))
         val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
         val stateGatherer =
             mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
-                coEvery { gatherInitialStatus(catalog) } returns
+                coEvery { gatherInitialStatus() } returns
                     mapOf(
                         stream to
                             DirectLoadInitialStatus(
@@ -140,14 +195,18 @@ internal class SnowflakeWriterTest {
             }
         val tempTableNameGenerator =
             mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } }
+        val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
         val writer =
             SnowflakeWriter(
-                names = catalog,
+                catalog = catalog,
                 stateGatherer = stateGatherer,
-                streamStateStore = mockk(),
+                streamStateStore = streamStateStore,
                 snowflakeClient = snowflakeClient,
                 tempTableNameGenerator = tempTableNameGenerator,
-                snowflakeConfiguration = mockk(relaxed = true),
+                snowflakeConfiguration =
+                    mockk(relaxed = true) {
+                        every { internalTableSchema } returns "internal_schema"
+                    },
             )
 
         runBlocking {
@@ -160,22 +219,35 @@ internal class SnowflakeWriterTest {
     @Test
     fun testCreateStreamLoaderHybrid() {
         val tableName = TableName(namespace = "test-namespace", name = "test-name")
-        val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
+        val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
+        val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
         val stream =
             mockk<DestinationStream> {
                 every { minimumGenerationId } returns 1L
                 every { generationId } returns 2L
+                every { importType } returns Append
+                every { tableSchema } returns
+                    StreamTableSchema(
+                        tableNames = tableNames,
+                        columnSchema =
+                            ColumnSchema(
+                                inputToFinalColumnNames = emptyMap(),
+                                finalSchema = emptyMap(),
+                                inputSchema = emptyMap()
+                            ),
+                        importType = Append
+                    )
+                every { mappedDescriptor } returns
+                    DestinationStream.Descriptor(
+                        namespace = tableName.namespace,
+                        name = tableName.name
+                    )
             }
-        val tableInfo =
-            TableNameInfo(
-                tableNames = tableNames,
-                columnNameMapping = ColumnNameMapping(emptyMap())
-            )
-        val catalog = TableCatalog(mapOf(stream to tableInfo))
+        val catalog = DestinationCatalog(listOf(stream))
         val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
         val stateGatherer =
             mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
-                coEvery { gatherInitialStatus(catalog) } returns
+                coEvery { gatherInitialStatus() } returns
                     mapOf(
                         stream to
                             DirectLoadInitialStatus(
@@ -186,14 +258,18 @@ internal class SnowflakeWriterTest {
             }
         val tempTableNameGenerator =
             mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } }
+        val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
         val writer =
             SnowflakeWriter(
-                names = catalog,
+                catalog = catalog,
                 stateGatherer = stateGatherer,
-                streamStateStore = mockk(),
+                streamStateStore = streamStateStore,
                 snowflakeClient = snowflakeClient,
                 tempTableNameGenerator = tempTableNameGenerator,
-                snowflakeConfiguration = mockk(relaxed = true),
+                snowflakeConfiguration =
+                    mockk(relaxed = true) {
+                        every { internalTableSchema } returns "internal_schema"
+                    },
             )
 
         runBlocking {
@@ -203,169 +279,126 @@ internal class SnowflakeWriterTest {
         }
 
     @Test
-    fun testSetupWithNamespaceCreationFailure() {
-        val tableName = TableName(namespace = "test-namespace", name = "test-name")
-        val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
-        val stream = mockk<DestinationStream>()
-        val tableInfo =
-            TableNameInfo(
-                tableNames = tableNames,
-                columnNameMapping = ColumnNameMapping(emptyMap())
-            )
-        val catalog = TableCatalog(mapOf(stream to tableInfo))
-        val snowflakeClient = mockk<SnowflakeAirbyteClient>()
-        val stateGatherer = mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>>()
-        val writer =
-            SnowflakeWriter(
-                names = catalog,
-                stateGatherer = stateGatherer,
-                streamStateStore = mockk(),
-                snowflakeClient = snowflakeClient,
-                tempTableNameGenerator = mockk(),
-                snowflakeConfiguration = mockk(),
-            )
-
-        // Simulate network failure during namespace creation
-        coEvery {
-            snowflakeClient.createNamespace(tableName.namespace.toSnowflakeCompatibleName())
-        } throws RuntimeException("Network connection failed")
-
-        assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } }
-    }
-
-    @Test
-    fun testSetupWithInitialStatusGatheringFailure() {
-        val tableName = TableName(namespace = "test-namespace", name = "test-name")
-        val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
-        val stream = mockk<DestinationStream>()
-        val tableInfo =
-            TableNameInfo(
-                tableNames = tableNames,
-                columnNameMapping = ColumnNameMapping(emptyMap())
-            )
-        val catalog = TableCatalog(mapOf(stream to tableInfo))
-        val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
-        val stateGatherer = mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>>()
-        val writer =
-            SnowflakeWriter(
-                names = catalog,
-                stateGatherer = stateGatherer,
-                streamStateStore = mockk(),
-                snowflakeClient = snowflakeClient,
-                tempTableNameGenerator = mockk(),
-                snowflakeConfiguration = mockk(),
-            )
-
-        // Simulate failure while gathering initial status
-        coEvery { stateGatherer.gatherInitialStatus(catalog) } throws
-            RuntimeException("Failed to query table status")
-
-        assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } }
-
-        // Verify namespace creation was still attempted
-        coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) }
-    }
-
-    @Test
-    fun testCreateStreamLoaderWithMissingInitialStatus() {
-        val tableName = TableName(namespace = "test-namespace", name = "test-name")
-        val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
+    fun testCreateStreamLoaderNamespaceLegacy() {
+        val namespace = "test-namespace"
+        val name = "test-name"
+        val tableName = TableName(namespace = namespace, name = name)
+        val tempTableName = TableName(namespace = namespace, name = "${name}-temp")
+        val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
         val stream =
             mockk<DestinationStream> {
                 every { minimumGenerationId } returns 0L
                 every { generationId } returns 0L
+                every { importType } returns Append
+                every { tableSchema } returns
+                    StreamTableSchema(
+                        tableNames = tableNames,
+                        columnSchema =
+                            ColumnSchema(
+                                inputToFinalColumnNames = emptyMap(),
+                                finalSchema = emptyMap(),
+                                inputSchema = emptyMap()
+                            ),
+                        importType = Append
+                    )
+                every { mappedDescriptor } returns
+                    DestinationStream.Descriptor(
+                        namespace = tableName.namespace,
+                        name = tableName.name
+                    )
             }
-        val missingStream =
-            mockk<DestinationStream> {
-                every { minimumGenerationId } returns 0L
-                every { generationId } returns 0L
-            }
-        val tableInfo =
-            TableNameInfo(
-                tableNames = tableNames,
-                columnNameMapping = ColumnNameMapping(emptyMap())
-            )
-        val catalog = TableCatalog(mapOf(stream to tableInfo))
+        val catalog = DestinationCatalog(listOf(stream))
         val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
         val stateGatherer =
             mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
-                coEvery { gatherInitialStatus(catalog) } returns
+                coEvery { gatherInitialStatus() } returns
                     mapOf(
                         stream to
                             DirectLoadInitialStatus(
                                 realTable = DirectLoadTableStatus(false),
-                                tempTable = null,
+                                tempTable = null
                             )
                     )
             }
+        val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
         val writer =
             SnowflakeWriter(
-                names = catalog,
+                catalog = catalog,
                 stateGatherer = stateGatherer,
-                streamStateStore = mockk(),
+                streamStateStore = streamStateStore,
                 snowflakeClient = snowflakeClient,
                 tempTableNameGenerator = mockk(),
-                snowflakeConfiguration = mockk(relaxed = true),
+                snowflakeConfiguration =
+                    mockk(relaxed = true) {
+                        every { legacyRawTablesOnly } returns true
+                        every { internalTableSchema } returns "internal_schema"
+                    },
            )
 
-        runBlocking {
-            writer.setup()
-            // Try to create loader for a stream that wasn't in initial status
-            assertThrows(NullPointerException::class.java) {
-                writer.createStreamLoader(missingStream)
+        runBlocking { writer.setup() }
+        coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) }
+    }
+
+    @Test
+    fun testCreateStreamLoaderNamespaceNonLegacy() {
+        val namespace = "test-namespace"
+        val name = "test-name"
+        val tableName = TableName(namespace = namespace, name = name)
+        val tempTableName = TableName(namespace = namespace, name = "${name}-temp")
+        val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
+        val stream =
+            mockk<DestinationStream> {
+                every { minimumGenerationId } returns 0L
+                every { generationId } returns 0L
+                every { importType } returns Append
+                every { tableSchema } returns
+                    StreamTableSchema(
+                        tableNames = tableNames,
+                        columnSchema =
+                            ColumnSchema(
+                                inputToFinalColumnNames = emptyMap(),
+                                finalSchema = emptyMap(),
+                                inputSchema = emptyMap()
+                            ),
+                        importType = Append
+                    )
+                every { mappedDescriptor } returns
+                    DestinationStream.Descriptor(
+                        namespace = tableName.namespace,
+                        name = tableName.name
+                    )
             }
-        }
-    }
-
-    @Test
-    fun testCreateStreamLoaderWithNullFinalTableName() {
-        // TableNames constructor throws IllegalStateException when both names are null
-        assertThrows(IllegalStateException::class.java) {
-            TableNames(rawTableName = null, finalTableName = null)
-        }
-    }
-
-    @Test
-    fun testSetupWithMultipleNamespaceFailuresPartial() {
-        val tableName1 = TableName(namespace = "namespace1", name = "table1")
-        val tableName2 = TableName(namespace = "namespace2", name = "table2")
-        val tableNames1 = TableNames(rawTableName = null, finalTableName = tableName1)
-        val tableNames2 = TableNames(rawTableName = null, finalTableName = tableName2)
-        val stream1 = mockk<DestinationStream>()
-        val stream2 = mockk<DestinationStream>()
-        val tableInfo1 =
-            TableNameInfo(
-                tableNames = tableNames1,
-                columnNameMapping = ColumnNameMapping(emptyMap())
-            )
-        val tableInfo2 =
-            TableNameInfo(
-                tableNames = tableNames2,
-                columnNameMapping = ColumnNameMapping(emptyMap())
-            )
-        val catalog = TableCatalog(mapOf(stream1 to tableInfo1, stream2 to tableInfo2))
-        val snowflakeClient = mockk<SnowflakeAirbyteClient>()
-        val stateGatherer = mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>>()
+        val catalog = DestinationCatalog(listOf(stream))
+        val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
+        val stateGatherer =
+            mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
+                coEvery { gatherInitialStatus() } returns
+                    mapOf(
+                        stream to
+                            DirectLoadInitialStatus(
+                                realTable = DirectLoadTableStatus(false),
+                                tempTable = null
+                            )
+                    )
+            }
+        val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
         val writer =
             SnowflakeWriter(
-                names = catalog,
+                catalog = catalog,
                 stateGatherer = stateGatherer,
-                streamStateStore = mockk(),
+                streamStateStore = streamStateStore,
                 snowflakeClient = snowflakeClient,
                 tempTableNameGenerator = mockk(),
-                snowflakeConfiguration = mockk(),
+                snowflakeConfiguration =
+                    mockk(relaxed = true) {
+                        every { legacyRawTablesOnly } returns false
+                        every { internalTableSchema } returns "internal_schema"
+                    },
            )
 
-        // First namespace succeeds, second fails (namespaces are uppercased by
-        // toSnowflakeCompatibleName)
-        coEvery { snowflakeClient.createNamespace("namespace1") } returns Unit
-        coEvery { snowflakeClient.createNamespace("namespace2") } throws
-            RuntimeException("Connection timeout")
-
-        assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } }
-
-        // Verify both namespace creations were attempted
-        coVerify(exactly = 1) { snowflakeClient.createNamespace("namespace1") }
-        coVerify(exactly = 1) { snowflakeClient.createNamespace("namespace2") }
+        runBlocking { writer.setup() }
+        coVerify(exactly = 1) { snowflakeClient.createNamespace(namespace) }
     }
 }
@@ -4,218 +4,220 @@
 
 package io.airbyte.integrations.destination.snowflake.write.load
 
+import io.airbyte.cdk.load.component.ColumnType
 import io.airbyte.cdk.load.data.AirbyteValue
+import io.airbyte.cdk.load.data.FieldType
 import io.airbyte.cdk.load.data.IntegerValue
 import io.airbyte.cdk.load.data.NullValue
+import io.airbyte.cdk.load.data.StringType
 import io.airbyte.cdk.load.data.StringValue
 import io.airbyte.cdk.load.message.Meta
-import io.airbyte.cdk.load.table.TableName
+import io.airbyte.cdk.load.schema.model.ColumnSchema
+import io.airbyte.cdk.load.schema.model.TableName
 import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
+import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
 import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
 import io.mockk.coVerify
 import io.mockk.every
 import io.mockk.mockk
 import java.io.BufferedReader
-import java.io.File
 import java.io.InputStreamReader
 import java.util.zip.GZIPInputStream
 import kotlin.io.path.exists
 import kotlinx.coroutines.runBlocking
 import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.Assertions.assertNotNull
 import org.junit.jupiter.api.BeforeEach
 import org.junit.jupiter.api.Test
 
 internal class SnowflakeInsertBufferTest {
 
     private lateinit var snowflakeConfiguration: SnowflakeConfiguration
-    private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
+    private lateinit var columnManager: SnowflakeColumnManager
+    private lateinit var columnSchema: ColumnSchema
+    private lateinit var snowflakeRecordFormatter: SnowflakeRecordFormatter
 
     @BeforeEach
     fun setUp() {
         snowflakeConfiguration = mockk(relaxed = true)
-        snowflakeColumnUtils = mockk(relaxed = true)
+        snowflakeRecordFormatter = SnowflakeSchemaRecordFormatter()
+        columnManager =
+            mockk(relaxed = true) {
+                every { getMetaColumns() } returns
+                    linkedMapOf(
+                        "_AIRBYTE_RAW_ID" to ColumnType("VARCHAR", false),
+                        "_AIRBYTE_EXTRACTED_AT" to ColumnType("TIMESTAMP_TZ", false),
+                        "_AIRBYTE_META" to ColumnType("VARIANT", false),
+                        "_AIRBYTE_GENERATION_ID" to ColumnType("NUMBER", true)
+                    )
+                every { getTableColumnNames(any()) } returns
+                    listOf(
+                        "_AIRBYTE_RAW_ID",
+                        "_AIRBYTE_EXTRACTED_AT",
+                        "_AIRBYTE_META",
+                        "_AIRBYTE_GENERATION_ID",
+                        "columnName"
+                    )
+            }
     }
 
     @Test
     fun testAccumulate() {
-        val tableName = mockk<TableName>(relaxed = true)
+        val tableName = TableName(namespace = "test", name = "table")
         val column = "columnName"
-        val columns = linkedMapOf(column to "NUMBER(38,0)")
+        columnSchema =
+            ColumnSchema(
+                inputToFinalColumnNames = mapOf(column to column.uppercase()),
+                finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
+                inputSchema = mapOf(column to FieldType(StringType, nullable = true))
+            )
         val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
         val record = createRecord(column)
         val buffer =
             SnowflakeInsertBuffer(
                 tableName = tableName,
-                columns = columns,
                 snowflakeClient = snowflakeAirbyteClient,
                 snowflakeConfiguration = snowflakeConfiguration,
-                flushLimit = 1,
-                snowflakeColumnUtils = snowflakeColumnUtils,
+                columnSchema = columnSchema,
+                columnManager = columnManager,
+                snowflakeRecordFormatter = snowflakeRecordFormatter,
            )
-        buffer.accumulate(record)
+        assertEquals(0, buffer.recordCount)
+        runBlocking { buffer.accumulate(record) }
 
-        assertEquals(true, buffer.csvFilePath?.exists())
         assertEquals(1, buffer.recordCount)
     }
 
     @Test
-    fun testAccumulateRaw() {
-        val tableName = mockk<TableName>(relaxed = true)
+    fun testFlushToStaging() {
+        val tableName = TableName(namespace = "test", name = "table")
         val column = "columnName"
-        val columns = linkedMapOf(column to "NUMBER(38,0)")
+        columnSchema =
+            ColumnSchema(
+                inputToFinalColumnNames = mapOf(column to column.uppercase()),
+                finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
+                inputSchema = mapOf(column to FieldType(StringType, nullable = true))
+            )
         val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
         val record = createRecord(column)
         val buffer =
             SnowflakeInsertBuffer(
                 tableName = tableName,
-                columns = columns,
                 snowflakeClient = snowflakeAirbyteClient,
                 snowflakeConfiguration = snowflakeConfiguration,
+                columnSchema = columnSchema,
+                columnManager = columnManager,
+                snowflakeRecordFormatter = snowflakeRecordFormatter,
                 flushLimit = 1,
-                snowflakeColumnUtils = snowflakeColumnUtils,
            )
-
-        every { snowflakeConfiguration.legacyRawTablesOnly } returns true
-
-        buffer.accumulate(record)
-
-        assertEquals(true, buffer.csvFilePath?.exists())
-        assertEquals(1, buffer.recordCount)
-    }
-
-    @Test
-    fun testFlush() {
-        val tableName = mockk<TableName>(relaxed = true)
-        val column = "columnName"
-        val columns = linkedMapOf(column to "NUMBER(38,0)")
-        val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
-        val record = createRecord(column)
-        val buffer =
-            SnowflakeInsertBuffer(
-                tableName = tableName,
-                columns = columns,
-                snowflakeClient = snowflakeAirbyteClient,
-                snowflakeConfiguration = snowflakeConfiguration,
-                flushLimit = 1,
-                snowflakeColumnUtils = snowflakeColumnUtils,
+        val expectedColumnNames =
+            listOf(
+                "_AIRBYTE_RAW_ID",
+                "_AIRBYTE_EXTRACTED_AT",
+                "_AIRBYTE_META",
+                "_AIRBYTE_GENERATION_ID",
+                "columnName"
            )
 
         runBlocking {
             buffer.accumulate(record)
             buffer.flush()
-        }
-
-        coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
-        coVerify(exactly = 1) {
-            snowflakeAirbyteClient.copyFromStage(
-                tableName,
-                match { it.endsWith("$CSV_FILE_EXTENSION$FILE_SUFFIX") }
-            )
+            coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
+            coVerify(exactly = 1) {
+                snowflakeAirbyteClient.copyFromStage(tableName, any(), expectedColumnNames)
+            }
         }
     }
 
     @Test
-    fun testFlushRaw() {
-        val tableName = mockk<TableName>(relaxed = true)
+    fun testFlushToNoStaging() {
+        val tableName = TableName(namespace = "test", name = "table")
         val column = "columnName"
-        val columns = linkedMapOf(column to "NUMBER(38,0)")
+        columnSchema =
+            ColumnSchema(
+                inputToFinalColumnNames = mapOf(column to column.uppercase()),
+                finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
+                inputSchema = mapOf(column to FieldType(StringType, nullable = true))
+            )
+        val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
+        every { snowflakeConfiguration.legacyRawTablesOnly } returns true
+        val record = createRecord(column)
+        val buffer =
+            SnowflakeInsertBuffer(
+                tableName = tableName,
+                snowflakeClient = snowflakeAirbyteClient,
+                snowflakeConfiguration = snowflakeConfiguration,
+                columnSchema = columnSchema,
+                columnManager = columnManager,
+                snowflakeRecordFormatter = snowflakeRecordFormatter,
+                flushLimit = 1,
+            )
+        val expectedColumnNames =
+            listOf(
+                "_AIRBYTE_RAW_ID",
+                "_AIRBYTE_EXTRACTED_AT",
+                "_AIRBYTE_META",
+                "_AIRBYTE_GENERATION_ID",
+                "columnName"
+            )
+        runBlocking {
+            buffer.accumulate(record)
+            buffer.flush()
+            // In legacy raw mode, it still uses staging
+            coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
+            coVerify(exactly = 1) {
+                snowflakeAirbyteClient.copyFromStage(tableName, any(), expectedColumnNames)
+            }
+        }
+    }
 
+    @Test
+    fun testFileCreation() {
+        val tableName = TableName(namespace = "test", name = "table")
+        val column = "columnName"
+        columnSchema =
+            ColumnSchema(
+                inputToFinalColumnNames = mapOf(column to column.uppercase()),
+                finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
+                inputSchema = mapOf(column to FieldType(StringType, nullable = true))
+            )
         val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
         val record = createRecord(column)
         val buffer =
             SnowflakeInsertBuffer(
                 tableName = tableName,
-                columns = columns,
                 snowflakeClient = snowflakeAirbyteClient,
                 snowflakeConfiguration = snowflakeConfiguration,
+                columnSchema = columnSchema,
+                columnManager = columnManager,
+                snowflakeRecordFormatter = snowflakeRecordFormatter,
                 flushLimit = 1,
-                snowflakeColumnUtils = snowflakeColumnUtils,
            )
 
-        every { snowflakeConfiguration.legacyRawTablesOnly } returns true
 
         runBlocking {
             buffer.accumulate(record)
-            buffer.flush()
-        }
-
-        coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
-        coVerify(exactly = 1) {
-            snowflakeAirbyteClient.copyFromStage(
-                tableName,
-                match { it.endsWith("$CSV_FILE_EXTENSION$FILE_SUFFIX") }
-            )
-        }
-    }
-
-    @Test
-    fun testMissingFields() {
-        val tableName = mockk<TableName>(relaxed = true)
-        val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
-        val record = createRecord("COLUMN1")
-        val buffer =
-            SnowflakeInsertBuffer(
-                tableName = tableName,
-                columns = linkedMapOf("COLUMN1" to "NUMBER(38,0)", "COLUMN2" to "NUMBER(38,0)"),
-                snowflakeClient = snowflakeAirbyteClient,
-                snowflakeConfiguration = snowflakeConfiguration,
-                flushLimit = 1,
-                snowflakeColumnUtils = snowflakeColumnUtils,
-            )
-
-        runBlocking {
-            buffer.accumulate(record)
-            buffer.csvWriter?.flush()
+            // The csvFilePath is internal, we can access it for testing
+            val filepath = buffer.csvFilePath
+            assertNotNull(filepath)
+            val file = filepath!!.toFile()
+            assert(file.exists())
+            // Close the writer to ensure all data is flushed
             buffer.csvWriter?.close()
-            assertEquals(
-                "test-value$CSV_FIELD_SEPARATOR$CSV_LINE_DELIMITER",
-                readFromCsvFile(buffer.csvFilePath!!.toFile())
-            )
+            val lines = mutableListOf<String>()
+            GZIPInputStream(file.inputStream()).use { gzip ->
+                BufferedReader(InputStreamReader(gzip)).use { bufferedReader ->
+                    bufferedReader.forEachLine { line -> lines.add(line) }
+                }
+            }
+            assertEquals(1, lines.size)
+            file.delete()
         }
     }
 
-    @Test
-    fun testMissingFieldsRaw() {
-        val tableName = mockk<TableName>(relaxed = true)
-        val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
-        val record = createRecord("COLUMN1")
-        val buffer =
-            SnowflakeInsertBuffer(
-                tableName = tableName,
-                columns = linkedMapOf("COLUMN1" to "NUMBER(38,0)", "COLUMN2" to "NUMBER(38,0)"),
-                snowflakeClient = snowflakeAirbyteClient,
-                snowflakeConfiguration = snowflakeConfiguration,
-                flushLimit = 1,
-                snowflakeColumnUtils = snowflakeColumnUtils,
-            )
-
-        every { snowflakeConfiguration.legacyRawTablesOnly } returns true
-
-        runBlocking {
-            buffer.accumulate(record)
-            buffer.csvWriter?.flush()
-            buffer.csvWriter?.close()
-            assertEquals(
-                "test-value$CSV_FIELD_SEPARATOR$CSV_LINE_DELIMITER",
-                readFromCsvFile(buffer.csvFilePath!!.toFile())
-            )
-        }
-    }
-
-    private fun readFromCsvFile(file: File) =
-        GZIPInputStream(file.inputStream()).use { input ->
-            val reader = BufferedReader(InputStreamReader(input))
+    private fun createRecord(column: String): Map<String, AirbyteValue> {
+        return mapOf(
+            column to IntegerValue(value = 42),
+            Meta.COLUMN_NAME_AB_GENERATION_ID to NullValue,
+            Meta.COLUMN_NAME_AB_RAW_ID to StringValue("raw-id-1"),
+            Meta.COLUMN_NAME_AB_EXTRACTED_AT to IntegerValue(1234567890),
+            Meta.COLUMN_NAME_AB_META to StringValue("meta-data-foo"),
|
|
||||||
reader.readText()
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun createRecord(columnName: String) =
|
|
||||||
mapOf(
|
|
||||||
columnName to AirbyteValue.from("test-value"),
|
|
||||||
Meta.COLUMN_NAME_AB_EXTRACTED_AT to IntegerValue(System.currentTimeMillis()),
|
|
||||||
Meta.COLUMN_NAME_AB_RAW_ID to StringValue("raw-id"),
|
|
||||||
Meta.COLUMN_NAME_AB_GENERATION_ID to IntegerValue(1223),
|
|
||||||
Meta.COLUMN_NAME_AB_META to StringValue("{\"changes\":[],\"syncId\":43}"),
|
|
||||||
"${columnName}Null" to NullValue
|
|
||||||
)
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
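Note on the refactor above: the buffer tests now describe a stream's columns with a single ColumnSchema value (input names, input types, and the final Snowflake schema) instead of a LinkedHashMap of column types plus a mocked SnowflakeColumnUtils. A minimal sketch of how such a schema can be assembled, using only the constructor shape visible in this diff; the helper name buildColumnSchema is illustrative, not part of the connector:

    // Sketch only: mirrors ColumnSchema(inputToFinalColumnNames, finalSchema, inputSchema)
    // as used by the tests above. buildColumnSchema is a hypothetical helper.
    fun buildColumnSchema(inputColumns: Map<String, FieldType>): ColumnSchema {
        val inputToFinal = inputColumns.keys.associateWith { it.toSnowflakeCompatibleName() }
        val finalSchema =
            inputColumns.keys.associate { name ->
                // Assumption: every column maps to a nullable VARCHAR; the connector's
                // real type mapping is richer than this.
                name.toSnowflakeCompatibleName() to ColumnType("VARCHAR(16777216)", true)
            }
        return ColumnSchema(
            inputToFinalColumnNames = inputToFinal,
            finalSchema = finalSchema,
            inputSchema = inputColumns
        )
    }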
@@ -15,12 +15,8 @@ import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
-import io.mockk.every
-import io.mockk.mockk
-import kotlin.collections.plus
+import io.airbyte.cdk.load.schema.model.ColumnSchema
 import org.junit.jupiter.api.Assertions.assertEquals
-import org.junit.jupiter.api.BeforeEach
 import org.junit.jupiter.api.Test

 private val AIRBYTE_COLUMN_TYPES_MAP =
@@ -58,28 +54,16 @@ private fun createExpected(

 internal class SnowflakeRawRecordFormatterTest {

-    private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
-
-    @BeforeEach
-    fun setup() {
-        snowflakeColumnUtils = mockk {
-            every { getFormattedDefaultColumnNames(any()) } returns
-                AIRBYTE_COLUMN_TYPES_MAP.keys.toList()
-        }
-    }
-
     @Test
     fun testFormatting() {
         val columnName = "test-column-name"
         val columnValue = "test-column-value"
         val columns = AIRBYTE_COLUMN_TYPES_MAP
         val record = createRecord(columnName = columnName, columnValue = columnValue)
-        val formatter =
-            SnowflakeRawRecordFormatter(
-                columns = AIRBYTE_COLUMN_TYPES_MAP,
-                snowflakeColumnUtils = snowflakeColumnUtils
-            )
-        val formattedValue = formatter.format(record)
+        val formatter = SnowflakeRawRecordFormatter()
+        // RawRecordFormatter doesn't use columnSchema but still needs one per interface
+        val dummyColumnSchema = ColumnSchema(emptyMap(), emptyMap(), emptyMap())
+        val formattedValue = formatter.format(record, dummyColumnSchema)
         val expectedValue =
             createExpected(
                 record = record,
@@ -93,33 +77,28 @@ internal class SnowflakeRawRecordFormatterTest {
     fun testFormattingMigratedFromPreviousVersion() {
         val columnName = "test-column-name"
         val columnValue = "test-column-value"
-        val columnsMap =
-            linkedMapOf(
-                COLUMN_NAME_AB_EXTRACTED_AT to "TIMESTAMP_TZ(9)",
-                COLUMN_NAME_AB_LOADED_AT to "TIMESTAMP_TZ(9)",
-                COLUMN_NAME_AB_META to "VARIANT",
-                COLUMN_NAME_DATA to "VARIANT",
-                COLUMN_NAME_AB_RAW_ID to "VARCHAR(16777216)",
-                COLUMN_NAME_AB_GENERATION_ID to "NUMBER(38,0)",
-            )
         val record = createRecord(columnName = columnName, columnValue = columnValue)
-        val formatter =
-            SnowflakeRawRecordFormatter(
-                columns = columnsMap,
-                snowflakeColumnUtils = snowflakeColumnUtils
-            )
-        val formattedValue = formatter.format(record)
+        val formatter = SnowflakeRawRecordFormatter()
+        // RawRecordFormatter doesn't use columnSchema but still needs one per interface
+        val dummyColumnSchema = ColumnSchema(emptyMap(), emptyMap(), emptyMap())
+        val formattedValue = formatter.format(record, dummyColumnSchema)
+        // The formatter outputs in a fixed order regardless of input column order:
+        // 1. AB_RAW_ID
+        // 2. AB_EXTRACTED_AT
+        // 3. AB_META
+        // 4. AB_GENERATION_ID
+        // 5. AB_LOADED_AT
+        // 6. DATA (JSON with remaining columns)
         val expectedValue =
-            createExpected(
-                    record = record,
-                    columns = columnsMap,
-                    airbyteColumns = columnsMap.keys.toList(),
-                )
-                .toMutableList()
-        expectedValue.add(
-            columnsMap.keys.indexOf(COLUMN_NAME_DATA),
-            "{\"$columnName\":\"$columnValue\"}"
-        )
+            listOf(
+                record[COLUMN_NAME_AB_RAW_ID]!!.toCsvValue(),
+                record[COLUMN_NAME_AB_EXTRACTED_AT]!!.toCsvValue(),
+                record[COLUMN_NAME_AB_META]!!.toCsvValue(),
+                record[COLUMN_NAME_AB_GENERATION_ID]!!.toCsvValue(),
+                record[COLUMN_NAME_AB_LOADED_AT]!!.toCsvValue(),
+                "{\"$columnName\":\"$columnValue\"}"
+            )
         assertEquals(expectedValue, formattedValue)
     }
 }
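The migrated-version test above pins down the raw formatter's contract: the output row has a fixed slot order no matter how the incoming record's keys are ordered. A hedged usage sketch (the record values are placeholders; only the format(record, columnSchema) signature and the six-slot order come from this diff):

    // Sketch only: the raw formatter emits slots in a fixed order.
    val record: Map<String, AirbyteValue> =
        mapOf(
            "user-column" to StringValue("value"),
            COLUMN_NAME_AB_RAW_ID to StringValue("raw-id"),
            COLUMN_NAME_AB_EXTRACTED_AT to IntegerValue(0),
            COLUMN_NAME_AB_META to StringValue("{}"),
            COLUMN_NAME_AB_GENERATION_ID to IntegerValue(1),
        )
    val row =
        SnowflakeRawRecordFormatter()
            .format(record, ColumnSchema(emptyMap(), emptyMap(), emptyMap()))
    // row: [raw id, extracted at, meta, generation id, loaded at,
    //       "{\"user-column\":\"value\"}"] regardless of the map's key order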
@@ -4,66 +4,58 @@

 package io.airbyte.integrations.destination.snowflake.write.load

+import io.airbyte.cdk.load.component.ColumnType
 import io.airbyte.cdk.load.data.AirbyteValue
+import io.airbyte.cdk.load.data.FieldType
 import io.airbyte.cdk.load.data.IntegerValue
 import io.airbyte.cdk.load.data.NullValue
+import io.airbyte.cdk.load.data.StringType
 import io.airbyte.cdk.load.data.StringValue
 import io.airbyte.cdk.load.data.csv.toCsvValue
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
 import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
-import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
-import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
-import io.mockk.every
-import io.mockk.mockk
-import java.util.AbstractMap
-import kotlin.collections.component1
-import kotlin.collections.component2
+import io.airbyte.cdk.load.schema.model.ColumnSchema
+import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
 import org.junit.jupiter.api.Assertions.assertEquals
-import org.junit.jupiter.api.BeforeEach
 import org.junit.jupiter.api.Test

-private val AIRBYTE_COLUMN_TYPES_MAP =
-    linkedMapOf(
-            COLUMN_NAME_AB_RAW_ID to "VARCHAR(16777216)",
-            COLUMN_NAME_AB_EXTRACTED_AT to "TIMESTAMP_TZ(9)",
-            COLUMN_NAME_AB_META to "VARIANT",
-            COLUMN_NAME_AB_GENERATION_ID to "NUMBER(38,0)",
-        )
-        .mapKeys { it.key.toSnowflakeCompatibleName() }

 internal class SnowflakeSchemaRecordFormatterTest {

-    private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
+    private fun createColumnSchema(userColumns: Map<String, String>): ColumnSchema {
+        val finalSchema = linkedMapOf<String, ColumnType>()
+        val inputToFinalColumnNames = mutableMapOf<String, String>()
+        val inputSchema = mutableMapOf<String, FieldType>()

-    @BeforeEach
-    fun setup() {
-        snowflakeColumnUtils = mockk {
-            every { getFormattedDefaultColumnNames(any()) } returns
-                AIRBYTE_COLUMN_TYPES_MAP.keys.toList()
+        // Add user columns
+        userColumns.forEach { (name, type) ->
+            val finalName = name.toSnowflakeCompatibleName()
+            finalSchema[finalName] = ColumnType(type, true)
+            inputToFinalColumnNames[name] = finalName
+            inputSchema[name] = FieldType(StringType, nullable = true)
         }

+        return ColumnSchema(
+            inputToFinalColumnNames = inputToFinalColumnNames,
+            finalSchema = finalSchema,
+            inputSchema = inputSchema
+        )
     }

     @Test
     fun testFormatting() {
         val columnName = "test-column-name"
         val columnValue = "test-column-value"
-        val columns =
-            (AIRBYTE_COLUMN_TYPES_MAP + linkedMapOf(columnName to "VARCHAR(16777216)")).mapKeys {
-                it.key.toSnowflakeCompatibleName()
-            }
+        val userColumns = mapOf(columnName to "VARCHAR(16777216)")
+        val columnSchema = createColumnSchema(userColumns)
         val record = createRecord(columnName, columnValue)
-        val formatter =
-            SnowflakeSchemaRecordFormatter(
-                columns = columns as LinkedHashMap<String, String>,
-                snowflakeColumnUtils = snowflakeColumnUtils
-            )
-        val formattedValue = formatter.format(record)
+        val formatter = SnowflakeSchemaRecordFormatter()
+        val formattedValue = formatter.format(record, columnSchema)
         val expectedValue =
             createExpected(
                 record = record,
-                columns = columns,
+                columnSchema = columnSchema,
             )
         assertEquals(expectedValue, formattedValue)
     }
@@ -72,21 +64,15 @@ internal class SnowflakeSchemaRecordFormatterTest {
     fun testFormattingVariant() {
         val columnName = "test-column-name"
         val columnValue = "{\"test\": \"test-value\"}"
-        val columns =
-            (AIRBYTE_COLUMN_TYPES_MAP + linkedMapOf(columnName to "VARIANT")).mapKeys {
-                it.key.toSnowflakeCompatibleName()
-            }
+        val userColumns = mapOf(columnName to "VARIANT")
+        val columnSchema = createColumnSchema(userColumns)
         val record = createRecord(columnName, columnValue)
-        val formatter =
-            SnowflakeSchemaRecordFormatter(
-                columns = columns as LinkedHashMap<String, String>,
-                snowflakeColumnUtils = snowflakeColumnUtils
-            )
-        val formattedValue = formatter.format(record)
+        val formatter = SnowflakeSchemaRecordFormatter()
+        val formattedValue = formatter.format(record, columnSchema)
         val expectedValue =
             createExpected(
                 record = record,
-                columns = columns,
+                columnSchema = columnSchema,
             )
         assertEquals(expectedValue, formattedValue)
     }
@@ -95,23 +81,16 @@ internal class SnowflakeSchemaRecordFormatterTest {
     fun testFormattingMissingColumn() {
         val columnName = "test-column-name"
         val columnValue = "test-column-value"
-        val columns =
-            AIRBYTE_COLUMN_TYPES_MAP +
-                linkedMapOf(
-                    columnName to "VARCHAR(16777216)",
-                    "missing-column" to "VARCHAR(16777216)"
-                )
+        val userColumns =
+            mapOf(columnName to "VARCHAR(16777216)", "missing-column" to "VARCHAR(16777216)")
+        val columnSchema = createColumnSchema(userColumns)
         val record = createRecord(columnName, columnValue)
-        val formatter =
-            SnowflakeSchemaRecordFormatter(
-                columns = columns as LinkedHashMap<String, String>,
-                snowflakeColumnUtils = snowflakeColumnUtils
-            )
-        val formattedValue = formatter.format(record)
+        val formatter = SnowflakeSchemaRecordFormatter()
+        val formattedValue = formatter.format(record, columnSchema)
         val expectedValue =
             createExpected(
                 record = record,
-                columns = columns,
+                columnSchema = columnSchema,
                 filterMissing = false,
             )
         assertEquals(expectedValue, formattedValue)
@@ -128,16 +107,37 @@ internal class SnowflakeSchemaRecordFormatterTest {

     private fun createExpected(
         record: Map<String, AirbyteValue>,
-        columns: Map<String, String>,
+        columnSchema: ColumnSchema,
         filterMissing: Boolean = true,
-    ) =
-        record.entries
-            .associate { entry -> entry.key.toSnowflakeCompatibleName() to entry.value }
-            .map { entry -> AbstractMap.SimpleEntry(entry.key, entry.value.toCsvValue()) }
-            .sortedBy { entry ->
-                if (columns.keys.indexOf(entry.key) > -1) columns.keys.indexOf(entry.key)
-                else Int.MAX_VALUE
+    ): List<Any> {
+        val columns = columnSchema.finalSchema.keys.toList()
+        val result = mutableListOf<Any>()
+
+        // Add meta columns first in the expected order
+        result.add(record[COLUMN_NAME_AB_RAW_ID]?.toCsvValue() ?: "")
+        result.add(record[COLUMN_NAME_AB_EXTRACTED_AT]?.toCsvValue() ?: "")
+        result.add(record[COLUMN_NAME_AB_META]?.toCsvValue() ?: "")
+        result.add(record[COLUMN_NAME_AB_GENERATION_ID]?.toCsvValue() ?: "")
+
+        // Add user columns
+        val userColumns =
+            columns.filterNot { col ->
+                listOf(
+                        COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName(),
+                        COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName(),
+                        COLUMN_NAME_AB_META.toSnowflakeCompatibleName(),
+                        COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()
+                    )
+                    .contains(col)
             }
-            .filter { (k, _) -> if (filterMissing) columns.contains(k) else true }
-            .map { it.value }
+
+        userColumns.forEach { columnName ->
+            val value = record[columnName] ?: if (!filterMissing) NullValue else null
+            if (value != null || !filterMissing) {
+                result.add(value?.toCsvValue() ?: "")
+            }
+        }
+
+        return result
+    }
 }
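With createExpected rewritten this way, the expected-value builder and the formatter agree on the same ordering rule: meta columns first, then user columns in finalSchema order. A short hedged sketch of a call site (the record value is a placeholder built like the tests' createRecord helper):

    // Sketch only: the schema formatter is now stateless; all column knowledge
    // travels in the ColumnSchema argument instead of constructor parameters.
    val columnSchema = createColumnSchema(mapOf("test-column-name" to "VARCHAR(16777216)"))
    val row = SnowflakeSchemaRecordFormatter().format(record, columnSchema)
    // row holds the four meta columns first, then user columns in finalSchema order.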
@@ -1,45 +0,0 @@
-/*
- * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.snowflake.write.transform
-
-import io.airbyte.cdk.load.command.DestinationStream
-import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
-import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
-import io.mockk.every
-import io.mockk.mockk
-import org.junit.jupiter.api.Assertions.assertEquals
-import org.junit.jupiter.api.Test
-
-internal class SnowflakeColumnNameMapperTest {
-
-    @Test
-    fun testGetMappedColumnName() {
-        val columnName = "tést-column-name"
-        val expectedName = "test-column-name"
-        val stream = mockk<DestinationStream>()
-        val tableCatalog = mockk<TableCatalog>()
-        val snowflakeConfiguration = mockk<SnowflakeConfiguration>(relaxed = true)
-
-        // Configure the mock to return the expected mapped column name
-        every { tableCatalog.getMappedColumnName(stream, columnName) } returns expectedName
-
-        val mapper = SnowflakeColumnNameMapper(tableCatalog, snowflakeConfiguration)
-        val result = mapper.getMappedColumnName(stream = stream, columnName = columnName)
-        assertEquals(expectedName, result)
-    }
-
-    @Test
-    fun testGetMappedColumnNameRawFormat() {
-        val columnName = "tést-column-name"
-        val stream = mockk<DestinationStream>()
-        val tableCatalog = mockk<TableCatalog>()
-        val snowflakeConfiguration =
-            mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns true }
-
-        val mapper = SnowflakeColumnNameMapper(tableCatalog, snowflakeConfiguration)
-        val result = mapper.getMappedColumnName(stream = stream, columnName = columnName)
-        assertEquals(columnName, result)
-    }
-}
@@ -260,6 +260,7 @@ desired namespace.

 | Version | Date | Pull Request | Subject |
 |:----------------|:-----------|:--------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| 4.0.32-rc.1 | 2025-12-18 | [70903](https://github.com/airbytehq/airbyte/pull/70903) | Upgrade to CDK 0.1.91; internal refactoring |
 | 4.0.31 | 2025-12-08 | [70442](https://github.com/airbytehq/airbyte/pull/70442) | Write VARIANT values correctly when underlying Airbyte type is a `union` |
 | 4.0.30 | 2025-11-24 | [69842](https://github.com/airbytehq/airbyte/pull/69842) | Update documentation about numeric value handling |
 | 4.0.29 | 2025-11-14 | [69342](https://github.com/airbytehq/airbyte/pull/69342) | Truncate NumberValues and IntegerValues with excessive precision instead of nullifying them |