
Merge branch 'master' into devin/1765393705-fix-youtube-analytics-job-creation

Aaron ("AJ") Steers
2025-12-18 16:11:00 -08:00
committed by GitHub
1402 changed files with 83973 additions and 10058 deletions

View File

@@ -139,8 +139,8 @@ runs:
CONNECTOR_VERSION_TAG="${{ inputs.tag-override }}"
echo "🏷 Using provided tag override: $CONNECTOR_VERSION_TAG"
elif [[ "${{ inputs.release-type }}" == "pre-release" ]]; then
hash=$(git rev-parse --short=10 HEAD)
CONNECTOR_VERSION_TAG="${CONNECTOR_VERSION}-dev.${hash}"
hash=$(git rev-parse --short=7 HEAD)
CONNECTOR_VERSION_TAG="${CONNECTOR_VERSION}-preview.${hash}"
echo "🏷 Using pre-release tag: $CONNECTOR_VERSION_TAG"
else
CONNECTOR_VERSION_TAG="$CONNECTOR_VERSION"

View File

@@ -21,7 +21,7 @@ As needed or by request, Airbyte Maintainers can execute the following slash com
- `/run-live-tests` - Runs live tests for the modified connector(s).
- `/run-regression-tests` - Runs regression tests for the modified connector(s).
- `/build-connector-images` - Builds and publishes a pre-release docker image for the modified connector(s).
- `/publish-connectors-prerelease` - Publishes pre-release connector builds (tagged as `{version}-dev.{git-sha}`) for all modified connectors in the PR.
- `/publish-connectors-prerelease` - Publishes pre-release connector builds (tagged as `{version}-preview.{git-sha}`) for all modified connectors in the PR.
If you have any questions, feel free to ask in the PR comments or join our [Slack community](https://airbytehq.slack.com/).

View File

@@ -28,7 +28,11 @@ Airbyte Maintainers (that's you!) can execute the following slash commands on yo
- `/run-live-tests` - Runs live tests for the modified connector(s).
- `/run-regression-tests` - Runs regression tests for the modified connector(s).
- `/build-connector-images` - Builds and publishes a pre-release docker image for the modified connector(s).
- `/publish-connectors-prerelease` - Publishes pre-release connector builds (tagged as `{version}-dev.{git-sha}`) for all modified connectors in the PR.
- `/publish-connectors-prerelease` - Publishes pre-release connector builds (tagged as `{version}-preview.{git-sha}`) for all modified connectors in the PR.
- Connector release lifecycle (AI-powered):
- `/ai-prove-fix` - Runs prerelease readiness checks, including testing against customer connections.
- `/ai-canary-prerelease` - Rolls out prerelease to 5-10 connections for canary testing.
- `/ai-release-watch` - Monitors rollout post-release and tracks sync success rates.
- JVM connectors:
- `/update-connector-cdk-version connector=<CONNECTOR_NAME>` - Updates the specified connector to the latest CDK version.
Example: `/update-connector-cdk-version connector=destination-bigquery`

View File

@@ -0,0 +1,72 @@
name: AI Canary Prerelease Command
on:
workflow_dispatch:
inputs:
pr:
description: "Pull request number (if triggered from a PR)"
type: number
required: false
comment-id:
description: "The comment-id of the slash command. Used to update the comment with the status."
required: false
repo:
description: "Repo (passed by slash command dispatcher)"
required: false
default: "airbytehq/airbyte"
gitref:
description: "Git ref (passed by slash command dispatcher)"
required: false
run-name: "AI Canary Prerelease for PR #${{ github.event.inputs.pr }}"
permissions:
contents: read
issues: write
pull-requests: read
jobs:
ai-canary-prerelease:
runs-on: ubuntu-latest
steps:
- name: Get job variables
id: job-vars
run: |
echo "run-url=https://github.com/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID" >> $GITHUB_OUTPUT
- name: Checkout code
uses: actions/checkout@v4
- name: Authenticate as GitHub App
uses: actions/create-github-app-token@v2
id: get-app-token
with:
owner: "airbytehq"
repositories: "airbyte,oncall"
app-id: ${{ secrets.OCTAVIA_BOT_APP_ID }}
private-key: ${{ secrets.OCTAVIA_BOT_PRIVATE_KEY }}
- name: Post start comment
if: inputs.comment-id != ''
uses: peter-evans/create-or-update-comment@v4
with:
token: ${{ steps.get-app-token.outputs.token }}
comment-id: ${{ inputs.comment-id }}
issue-number: ${{ inputs.pr }}
body: |
> **AI Canary Prerelease Started**
>
> Rolling out to 5-10 connections, watching results, and reporting findings.
> [View workflow run](${{ steps.job-vars.outputs.run-url }})
- name: Run AI Canary Prerelease
uses: aaronsteers/devin-action@main
with:
comment-id: ${{ inputs.comment-id }}
issue-number: ${{ inputs.pr }}
playbook-macro: "!canary_prerelease"
devin-token: ${{ secrets.DEVIN_AI_API_KEY }}
github-token: ${{ steps.get-app-token.outputs.token }}
start-message: "🐤 **AI Canary Prerelease session starting...** Rolling out to 5-10 connections, watching results, and reporting findings. [View playbook](https://github.com/airbytehq/oncall/blob/main/prompts/playbooks/canary_prerelease.md)"
tags: |
ai-oncall

View File

@@ -0,0 +1,72 @@
name: AI Prove Fix Command
on:
workflow_dispatch:
inputs:
pr:
description: "Pull request number (if triggered from a PR)"
type: number
required: false
comment-id:
description: "The comment-id of the slash command. Used to update the comment with the status."
required: false
repo:
description: "Repo (passed by slash command dispatcher)"
required: false
default: "airbytehq/airbyte"
gitref:
description: "Git ref (passed by slash command dispatcher)"
required: false
run-name: "AI Prove Fix for PR #${{ github.event.inputs.pr }}"
permissions:
contents: read
issues: write
pull-requests: read
jobs:
ai-prove-fix:
runs-on: ubuntu-latest
steps:
- name: Get job variables
id: job-vars
run: |
echo "run-url=https://github.com/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID" >> $GITHUB_OUTPUT
- name: Checkout code
uses: actions/checkout@v4
- name: Authenticate as GitHub App
uses: actions/create-github-app-token@v2
id: get-app-token
with:
owner: "airbytehq"
repositories: "airbyte,oncall"
app-id: ${{ secrets.OCTAVIA_BOT_APP_ID }}
private-key: ${{ secrets.OCTAVIA_BOT_PRIVATE_KEY }}
- name: Post start comment
if: inputs.comment-id != ''
uses: peter-evans/create-or-update-comment@v4
with:
token: ${{ steps.get-app-token.outputs.token }}
comment-id: ${{ inputs.comment-id }}
issue-number: ${{ inputs.pr }}
body: |
> **AI Prove Fix Started**
>
> Running readiness checks and testing against customer connections.
> [View workflow run](${{ steps.job-vars.outputs.run-url }})
- name: Run AI Prove Fix
uses: aaronsteers/devin-action@main
with:
comment-id: ${{ inputs.comment-id }}
issue-number: ${{ inputs.pr }}
playbook-macro: "!prove_fix"
devin-token: ${{ secrets.DEVIN_AI_API_KEY }}
github-token: ${{ steps.get-app-token.outputs.token }}
start-message: "🔍 **AI Prove Fix session starting...** Running readiness checks and testing against customer connections. [View playbook](https://github.com/airbytehq/oncall/blob/main/prompts/playbooks/prove_fix.md)"
tags: |
ai-oncall

View File

@@ -0,0 +1,72 @@
name: AI Release Watch Command
on:
workflow_dispatch:
inputs:
pr:
description: "Pull request number (if triggered from a PR)"
type: number
required: false
comment-id:
description: "The comment-id of the slash command. Used to update the comment with the status."
required: false
repo:
description: "Repo (passed by slash command dispatcher)"
required: false
default: "airbytehq/airbyte"
gitref:
description: "Git ref (passed by slash command dispatcher)"
required: false
run-name: "AI Release Watch for PR #${{ github.event.inputs.pr }}"
permissions:
contents: read
issues: write
pull-requests: read
jobs:
ai-release-watch:
runs-on: ubuntu-latest
steps:
- name: Get job variables
id: job-vars
run: |
echo "run-url=https://github.com/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID" >> $GITHUB_OUTPUT
- name: Checkout code
uses: actions/checkout@v4
- name: Authenticate as GitHub App
uses: actions/create-github-app-token@v2
id: get-app-token
with:
owner: "airbytehq"
repositories: "airbyte,oncall"
app-id: ${{ secrets.OCTAVIA_BOT_APP_ID }}
private-key: ${{ secrets.OCTAVIA_BOT_PRIVATE_KEY }}
- name: Post start comment
if: inputs.comment-id != ''
uses: peter-evans/create-or-update-comment@v4
with:
token: ${{ steps.get-app-token.outputs.token }}
comment-id: ${{ inputs.comment-id }}
issue-number: ${{ inputs.pr }}
body: |
> **AI Release Watch Started**
>
> Monitoring rollout and tracking sync success rates.
> [View workflow run](${{ steps.job-vars.outputs.run-url }})
- name: Run AI Release Watch
uses: aaronsteers/devin-action@main
with:
comment-id: ${{ inputs.comment-id }}
issue-number: ${{ inputs.pr }}
playbook-macro: "!release_watch"
devin-token: ${{ secrets.DEVIN_AI_API_KEY }}
github-token: ${{ steps.get-app-token.outputs.token }}
start-message: "👁️ **AI Release Watch session starting...** Monitoring rollout and tracking sync success rates. [View playbook](https://github.com/airbytehq/oncall/blob/main/prompts/playbooks/release_watch.md)"
tags: |
ai-oncall

View File

@@ -104,7 +104,7 @@ jobs:
if: steps.check-support-level.outputs.metadata_file == 'true' && steps.check-support-level.outputs.community_support == 'true'
env:
PROMPT_TEXT: "The commit to review is ${{ github.sha }}. This commit was pushed to master and may contain connector changes that need documentation updates."
uses: aaronsteers/devin-action@0d74d6d9ff1b16ada5966dc31af53a9d155759f4 # Pinned to specific commit for security
uses: aaronsteers/devin-action@98d15ae93d1848914f5ab8e9ce45341182958d27 # v0.1.7 - Pinned to specific commit for security
with:
devin-token: ${{ secrets.DEVIN_AI_API_KEY }}
github-token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -0,0 +1,28 @@
name: Label Community PRs
# This workflow automatically adds the "community" label to PRs from forks.
# This enables automatic tracking on the Community PRs project board.
on:
pull_request_target:
types:
- opened
- reopened
jobs:
label-community-pr:
name: Add "Community" Label to PR
# Only run for PRs from forks
if: github.event.pull_request.head.repo.fork == true
runs-on: ubuntu-24.04
permissions:
issues: write
pull-requests: write
steps:
- name: Add community label
# This action uses GitHub's addLabels API, which is idempotent.
# If the label already exists, the API call succeeds without error.
uses: actions-ecosystem/action-add-labels@bd52874380e3909a1ac983768df6976535ece7f8 # v1.1.3
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
labels: community

View File

@@ -3,7 +3,7 @@ name: Publish Connectors Pre-release
# It can be triggered via the /publish-connectors-prerelease slash command from PR comments,
# or via the MCP tool `publish_connector_to_airbyte_registry`.
#
# Pre-release versions are tagged with the format: {version}-dev.{10-char-git-sha}
# Pre-release versions are tagged with the format: {version}-preview.{7-char-git-sha}
# These versions are NOT eligible for semver auto-advancement but ARE available
# for version pinning via the scoped_configuration API.
#
@@ -66,7 +66,7 @@ jobs:
- name: Get short SHA
id: get-sha
run: |
SHORT_SHA=$(git rev-parse --short=10 HEAD)
SHORT_SHA=$(git rev-parse --short=7 HEAD)
echo "short-sha=$SHORT_SHA" >> $GITHUB_OUTPUT
- name: Get job variables
@@ -135,7 +135,7 @@ jobs:
> Publishing pre-release build for connector `${{ steps.resolve-connector.outputs.connector-name }}`.
> Branch: `${{ inputs.gitref }}`
>
> Pre-release versions will be tagged as `{version}-dev.${{ steps.get-sha.outputs.short-sha }}`
> Pre-release versions will be tagged as `{version}-preview.${{ steps.get-sha.outputs.short-sha }}`
> and are available for version pinning via the scoped_configuration API.
>
> [View workflow run](${{ steps.job-vars.outputs.run-url }})
@@ -147,6 +147,7 @@ jobs:
with:
connectors: ${{ format('--name={0}', needs.init.outputs.connector-name) }}
release-type: pre-release
gitref: ${{ inputs.gitref }}
secrets: inherit
post-completion:
@@ -176,13 +177,12 @@ jobs:
id: message-vars
run: |
CONNECTOR_NAME="${{ needs.init.outputs.connector-name }}"
SHORT_SHA="${{ needs.init.outputs.short-sha }}"
VERSION="${{ needs.init.outputs.connector-version }}"
# Use the actual docker-image-tag from the publish workflow output
DOCKER_TAG="${{ needs.publish.outputs.docker-image-tag }}"
if [[ -n "$VERSION" ]]; then
DOCKER_TAG="${VERSION}-dev.${SHORT_SHA}"
else
DOCKER_TAG="{version}-dev.${SHORT_SHA}"
if [[ -z "$DOCKER_TAG" ]]; then
echo "::error::docker-image-tag output is missing from publish workflow. This is unexpected."
exit 1
fi
echo "connector_name=$CONNECTOR_NAME" >> $GITHUB_OUTPUT

View File

@@ -21,6 +21,14 @@ on:
required: false
default: false
type: boolean
gitref:
description: "Git ref (branch or SHA) to build connectors from. Used by pre-release workflow to build from PR branches."
required: false
type: string
outputs:
docker-image-tag:
description: "Docker image tag used when publishing. For single-connector callers only; multi-connector callers should not rely on this output."
value: ${{ jobs.publish_connector_registry_entries.outputs.docker-image-tag }}
workflow_dispatch:
inputs:
connectors:
@@ -48,6 +56,7 @@ jobs:
# v4
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955
with:
ref: ${{ inputs.gitref || '' }}
fetch-depth: 2 # Required so we can conduct a diff from the previous commit to understand what connectors have changed.
submodules: true # Required for the enterprise repo since it uses a submodule that needs to exist for this workflow to run successfully.
- name: List connectors to publish [manual]
@@ -105,6 +114,7 @@ jobs:
# v4
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955
with:
ref: ${{ inputs.gitref || '' }}
fetch-depth: 2 # Required so we can conduct a diff from the previous commit to understand what connectors have changed.
submodules: true # Required for the enterprise repo since it uses a submodule that needs to exist for this workflow to run successfully.
@@ -250,11 +260,14 @@ jobs:
max-parallel: 5
# Allow all jobs to run, even if one fails
fail-fast: false
outputs:
docker-image-tag: ${{ steps.connector-metadata.outputs.docker-image-tag }}
steps:
- name: Checkout Airbyte
# v4
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955
with:
ref: ${{ inputs.gitref || '' }}
fetch-depth: 2 # Required so we can conduct a diff from the previous commit to understand what connectors have changed.
submodules: true # Required for the enterprise repo since it uses a submodule that needs to exist for this workflow to run successfully.
@@ -292,8 +305,8 @@ jobs:
echo "connector-version=$(poe -qq get-version)" | tee -a $GITHUB_OUTPUT
CONNECTOR_VERSION=$(poe -qq get-version)
if [[ "${{ inputs.release-type }}" == "pre-release" ]]; then
hash=$(git rev-parse --short=10 HEAD)
echo "docker-image-tag=${CONNECTOR_VERSION}-dev.${hash}" | tee -a $GITHUB_OUTPUT
hash=$(git rev-parse --short=7 HEAD)
echo "docker-image-tag=${CONNECTOR_VERSION}-preview.${hash}" | tee -a $GITHUB_OUTPUT
echo "release-type-flag=--pre-release" | tee -a $GITHUB_OUTPUT
else
echo "docker-image-tag=${CONNECTOR_VERSION}" | tee -a $GITHUB_OUTPUT
@@ -349,6 +362,7 @@ jobs:
# v4
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955
with:
ref: ${{ inputs.gitref || '' }}
submodules: true # Required for the enterprise repo since it uses a submodule that needs to exist for this workflow to run successfully.
- name: Match GitHub User to Slack User
id: match-github-to-slack-user
@@ -381,6 +395,7 @@ jobs:
# v4
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955
with:
ref: ${{ inputs.gitref || '' }}
submodules: true # Required for the enterprise repo since it uses a submodule that needs to exist for this workflow to run successfully.
- name: Notify PagerDuty
id: pager-duty

View File

@@ -35,6 +35,9 @@ jobs:
issue-type: both
commands: |
ai-canary-prerelease
ai-prove-fix
ai-release-watch
approve-regression-tests
bump-bulk-cdk-version
bump-progressive-rollout-version

View File

@@ -0,0 +1,70 @@
name: Sync Agent Connector Docs
on:
schedule:
- cron: "0 */2 * * *" # Every 2 hours
workflow_dispatch: # Manual trigger
jobs:
sync-docs:
runs-on: ubuntu-latest
steps:
- name: Checkout airbyte repo
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
- name: Checkout airbyte-agent-connectors
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
repository: airbytehq/airbyte-agent-connectors
path: agent-connectors-source
- name: Sync connector docs
run: |
DEST_DIR="docs/ai-agents/connectors"
mkdir -p "$DEST_DIR"
for connector_dir in agent-connectors-source/connectors/*/; do
connector=$(basename "$connector_dir")
# Only delete/recreate the specific connector subdirectory
# This leaves any files directly in $DEST_DIR untouched
rm -rf "$DEST_DIR/$connector"
mkdir -p "$DEST_DIR/$connector"
# Copy all markdown files for this connector
for md_file in "$connector_dir"/*.md; do
if [ -f "$md_file" ]; then
cp "$md_file" "$DEST_DIR/$connector/"
fi
done
done
echo "Synced $(ls -d $DEST_DIR/*/ 2>/dev/null | wc -l) connectors"
- name: Cleanup temporary checkout
run: rm -rf agent-connectors-source
- name: Authenticate as GitHub App
uses: actions/create-github-app-token@v2
id: get-app-token
with:
owner: "airbytehq"
repositories: "airbyte"
app-id: ${{ secrets.OCTAVIA_BOT_APP_ID }}
private-key: ${{ secrets.OCTAVIA_BOT_PRIVATE_KEY }}
- name: Create PR if changes
uses: peter-evans/create-pull-request@0979079bc20c05bbbb590a56c21c4e2b1d1f1bbe # v6
with:
token: ${{ steps.get-app-token.outputs.token }}
commit-message: "docs: sync agent connector docs from airbyte-agent-connectors repo"
branch: auto-sync-ai-connector-docs
delete-branch: true
title: "docs: sync agent connector docs from airbyte-agent-connectors repo"
body: |
Automated sync of agent connector docs from airbyte-agent-connectors.
This PR was automatically created by the sync-agent-connector-docs workflow.
labels: |
documentation
auto-merge

.markdownlintignore (new file, 3 lines)
View File

@@ -0,0 +1,3 @@
# Ignore auto-generated connector documentation files synced from airbyte-agent-connectors repo
# These files are generated and have formatting that doesn't conform to markdownlint rules
docs/ai-agents/connectors/**

View File

@@ -1,3 +1,21 @@
## Version 0.1.91
load cdk: upsert records test uses proper target schema
## Version 0.1.90
load cdk: components tests: data coercion tests cover all data types
## Version 0.1.89
load cdk: components tests: data coercion tests for int+number
## Version 0.1.88
**Load CDK**
* Add CDC_CURSOR_COLUMN_NAME constant.
## Version 0.1.87
**Load CDK**

View File

@@ -4,4 +4,13 @@
package io.airbyte.cdk.load.table
/**
* CDC meta column names.
*
* Note: These CDC column names are brittle: they are defined here separately from, yet must stay
* in sync with, the logic sources use to generate these column names. See
* [io.airbyte.integrations.source.mssql.MsSqlSourceOperations.MsSqlServerCdcMetaFields] for an
* example.
*/
const val CDC_DELETED_AT_COLUMN = "_ab_cdc_deleted_at"
const val CDC_CURSOR_COLUMN = "_ab_cdc_cursor"
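For illustration only, a destination that needs to special-case CDC deletes might key off these constants instead of re-deriving the column names. The record shape and the `isCdcDelete` helper below are hypothetical, not part of the CDK:

```kotlin
import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN

// Hypothetical helper: treat a record as a CDC delete when the deleted-at meta column
// carries a value. The Map<String, Any?> record shape is assumed for this sketch.
fun isCdcDelete(record: Map<String, Any?>): Boolean = record[CDC_DELETED_AT_COLUMN] != null
```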

View File

@@ -0,0 +1,859 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.component
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.ArrayValue
import io.airbyte.cdk.load.data.DateValue
import io.airbyte.cdk.load.data.IntegerValue
import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.data.NumberValue
import io.airbyte.cdk.load.data.ObjectValue
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.data.TimeWithTimezoneValue
import io.airbyte.cdk.load.data.TimeWithoutTimezoneValue
import io.airbyte.cdk.load.data.TimestampWithTimezoneValue
import io.airbyte.cdk.load.data.TimestampWithoutTimezoneValue
import io.airbyte.cdk.load.dataflow.transform.ValueCoercer
import io.airbyte.cdk.load.util.serializeToString
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Reason
import java.math.BigDecimal
import java.math.BigInteger
import java.time.LocalDate
import java.time.LocalDateTime
import java.time.OffsetDateTime
import java.time.format.DateTimeFormatter
import java.time.format.DateTimeFormatterBuilder
import java.time.format.SignStyle
import java.time.temporal.ChronoField
import org.junit.jupiter.params.provider.Arguments
/*
* This file defines "interesting values" for all data types, along with expected behavior for those values.
* You're free to define your own values/behavior depending on the destination, but it's recommended
* that you try to match behavior to an existing fixture.
*
* Classes also include some convenience functions for JUnit. For example, you could annotate your
* method with:
* ```kotlin
* @ParameterizedTest
* @MethodSource("io.airbyte.cdk.load.component.DataCoercionIntegerFixtures#int64")
* ```
*
* By convention, all fixtures are declared as:
* 1. One or more `val <name>: List<DataCoercionTestCase>` (each case giving a short name, the input
* value, the expected output value, and an optional change reason)
* 2. One or more `fun <name>(): List<Arguments> = <name>.toArgs()`, which can be provided to JUnit's MethodSource
*
* If you need to mutate fixtures in some way, you should reference the `val`, and use the `toArgs()`
* extension function to convert it to JUnit's Arguments class. See [DataCoercionIntegerFixtures.int64AsBigInteger]
* for an example.
*/
object DataCoercionIntegerFixtures {
// "9".repeat(38)
val numeric38_0Max = bigint("99999999999999999999999999999999999999")
val numeric38_0Min = bigint("-99999999999999999999999999999999999999")
const val ZERO = "0"
const val ONE = "1"
const val NEGATIVE_ONE = "-1"
const val FORTY_TWO = "42"
const val NEGATIVE_FORTY_TWO = "-42"
const val INT32_MAX = "int32 max"
const val INT32_MIN = "int32 min"
const val INT32_MAX_PLUS_ONE = "int32_max + 1"
const val INT32_MIN_MINUS_ONE = "int32_min - 1"
const val INT64_MAX = "int64 max"
const val INT64_MIN = "int64 min"
const val INT64_MAX_PLUS_ONE = "int64_max + 1"
const val INT64_MIN_MINUS_1 = "int64_min - 1"
const val NUMERIC_38_0_MAX = "numeric(38,0) max"
const val NUMERIC_38_0_MIN = "numeric(38,0) min"
const val NUMERIC_38_0_MAX_PLUS_ONE = "numeric(38,0)_max + 1"
const val NUMERIC_38_0_MIN_MINUS_ONE = "numeric(38,0)_min - 1"
/**
* Many destinations use int64 to represent integers. In this case, we null out any value beyond
* Long.MIN/MAX_VALUE.
*/
val int64 =
listOf(
case(NULL, NullValue, null),
case(ZERO, IntegerValue(0), 0L),
case(ONE, IntegerValue(1), 1L),
case(NEGATIVE_ONE, IntegerValue(-1), -1L),
case(FORTY_TWO, IntegerValue(42), 42L),
case(NEGATIVE_FORTY_TWO, IntegerValue(-42), -42L),
// int32 bounds, and slightly out of bounds
case(INT32_MAX, IntegerValue(Integer.MAX_VALUE.toLong()), Integer.MAX_VALUE.toLong()),
case(INT32_MIN, IntegerValue(Integer.MIN_VALUE.toLong()), Integer.MIN_VALUE.toLong()),
case(
INT32_MAX_PLUS_ONE,
IntegerValue(Integer.MAX_VALUE.toLong() + 1),
Integer.MAX_VALUE.toLong() + 1
),
case(
INT32_MIN_MINUS_ONE,
IntegerValue(Integer.MIN_VALUE.toLong() - 1),
Integer.MIN_VALUE.toLong() - 1
),
// int64 bounds, and slightly out of bounds
case(INT64_MAX, IntegerValue(Long.MAX_VALUE), Long.MAX_VALUE),
case(INT64_MIN, IntegerValue(Long.MIN_VALUE), Long.MIN_VALUE),
// values out of int64 bounds are nulled
case(
INT64_MAX_PLUS_ONE,
IntegerValue(bigint(Long.MAX_VALUE) + BigInteger.ONE),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
INT64_MIN_MINUS_1,
IntegerValue(bigint(Long.MIN_VALUE) - BigInteger.ONE),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
// NUMERIC(38, 0) bounds, and slightly out of bounds
// (these are all out of bounds for an int64 value, so they all get nulled)
case(
NUMERIC_38_0_MAX,
IntegerValue(numeric38_0Max),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
NUMERIC_38_0_MIN,
IntegerValue(numeric38_0Min),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
NUMERIC_38_0_MAX_PLUS_ONE,
IntegerValue(numeric38_0Max + BigInteger.ONE),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
NUMERIC_38_0_MIN_MINUS_ONE,
IntegerValue(numeric38_0Min - BigInteger.ONE),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
)
/**
* Many destination warehouses represent integers as a fixed-point type with 38 digits of
* precision. In this case, we only need to null out numbers larger than `1e38 - 1` / smaller
* than `-1e38 + 1`.
*/
val numeric38_0 =
listOf(
case(NULL, NullValue, null),
case(ZERO, IntegerValue(0), bigint(0L)),
case(ONE, IntegerValue(1), bigint(1L)),
case(NEGATIVE_ONE, IntegerValue(-1), bigint(-1L)),
case(FORTY_TWO, IntegerValue(42), bigint(42L)),
case(NEGATIVE_FORTY_TWO, IntegerValue(-42), bigint(-42L)),
// int32 bounds, and slightly out of bounds
case(
INT32_MAX,
IntegerValue(Integer.MAX_VALUE.toLong()),
bigint(Integer.MAX_VALUE.toLong())
),
case(
INT32_MIN,
IntegerValue(Integer.MIN_VALUE.toLong()),
bigint(Integer.MIN_VALUE.toLong())
),
case(
INT32_MAX_PLUS_ONE,
IntegerValue(Integer.MAX_VALUE.toLong() + 1),
bigint(Integer.MAX_VALUE.toLong() + 1)
),
case(
INT32_MIN_MINUS_ONE,
IntegerValue(Integer.MIN_VALUE.toLong() - 1),
bigint(Integer.MIN_VALUE.toLong() - 1)
),
// int64 bounds, and slightly out of bounds
case(INT64_MAX, IntegerValue(Long.MAX_VALUE), bigint(Long.MAX_VALUE)),
case(INT64_MIN, IntegerValue(Long.MIN_VALUE), bigint(Long.MIN_VALUE)),
case(
INT64_MAX_PLUS_ONE,
IntegerValue(bigint(Long.MAX_VALUE) + BigInteger.ONE),
bigint(Long.MAX_VALUE) + BigInteger.ONE
),
case(
INT64_MIN_MINUS_1,
IntegerValue(bigint(Long.MIN_VALUE) - BigInteger.ONE),
bigint(Long.MIN_VALUE) - BigInteger.ONE
),
// NUMERIC(38, 0) bounds, and slightly out of bounds
case(NUMERIC_38_0_MAX, IntegerValue(numeric38_0Max), numeric38_0Max),
case(NUMERIC_38_0_MIN, IntegerValue(numeric38_0Min), numeric38_0Min),
// These values exceed the 38-digit range, so they get nulled out
case(
NUMERIC_38_0_MAX_PLUS_ONE,
IntegerValue(numeric38_0Max + BigInteger.ONE),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
NUMERIC_38_0_MIN_MINUS_ONE,
IntegerValue(numeric38_0Min - BigInteger.ONE),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
)
@JvmStatic fun int64() = int64.toArgs()
/**
* Convenience fixture if your [TestTableOperationsClient] returns integers as [BigInteger]
* rather than [Long].
*/
@JvmStatic
fun int64AsBigInteger() =
int64.map { it.copy(outputValue = it.outputValue?.let { bigint(it as Long) }) }
/**
* Convenience fixture if your [TestTableOperationsClient] returns integers as [BigDecimal]
* rather than [Long].
*/
@JvmStatic
fun int64AsBigDecimal() =
int64.map { it.copy(outputValue = it.outputValue?.let { BigDecimal.valueOf(it as Long) }) }
@JvmStatic fun numeric38_0() = numeric38_0.toArgs()
}
object DataCoercionNumberFixtures {
val numeric38_9Max = bigdec("99999999999999999999999999999.999999999")
val numeric38_9Min = bigdec("-99999999999999999999999999999.999999999")
const val ZERO = "0"
const val ONE = "1"
const val NEGATIVE_ONE = "-1"
const val ONE_HUNDRED_TWENTY_THREE_POINT_FOUR = "123.4"
const val NEGATIVE_ONE_HUNDRED_TWENTY_THREE_POINT_FOUR = "-123.4"
const val POSITIVE_HIGH_PRECISION_FLOAT = "positive high-precision float"
const val NEGATIVE_HIGH_PRECISION_FLOAT = "negative high-precision float"
const val NUMERIC_38_9_MAX = "numeric(38,9) max"
const val NUMERIC_38_9_MIN = "numeric(38,9) min"
const val SMALLEST_POSITIVE_FLOAT32 = "smallest positive float32"
const val SMALLEST_NEGATIVE_FLOAT32 = "smallest negative float32"
const val LARGEST_POSITIVE_FLOAT32 = "largest positive float32"
const val LARGEST_NEGATIVE_FLOAT32 = "largest negative float32"
const val SMALLEST_POSITIVE_FLOAT64 = "smallest positive float64"
const val SMALLEST_NEGATIVE_FLOAT64 = "smallest negative float64"
const val LARGEST_POSITIVE_FLOAT64 = "largest positive float64"
const val LARGEST_NEGATIVE_FLOAT64 = "largest negative float64"
const val SLIGHTLY_ABOVE_LARGEST_POSITIVE_FLOAT64 = "slightly above largest positive float64"
const val SLIGHTLY_BELOW_LARGEST_NEGATIVE_FLOAT64 = "slightly below largest negative float64"
val float64 =
listOf(
case(NULL, NullValue, null),
case(ZERO, NumberValue(bigdec(0)), 0.0),
case(ONE, NumberValue(bigdec(1)), 1.0),
case(NEGATIVE_ONE, NumberValue(bigdec(-1)), -1.0),
// This value isn't exactly representable as a float64
// (the exact value is `123.400000000000005684341886080801486968994140625`)
// but we should preserve the canonical representation
case(ONE_HUNDRED_TWENTY_THREE_POINT_FOUR, NumberValue(bigdec("123.4")), 123.4),
case(
NEGATIVE_ONE_HUNDRED_TWENTY_THREE_POINT_FOUR,
NumberValue(bigdec("-123.4")),
-123.4
),
// These values have too much precision for a float64, so we round them
case(
POSITIVE_HIGH_PRECISION_FLOAT,
NumberValue(bigdec("1234567890.1234567890123456789")),
1234567890.1234567,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
NEGATIVE_HIGH_PRECISION_FLOAT,
NumberValue(bigdec("-1234567890.1234567890123456789")),
-1234567890.1234567,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
NUMERIC_38_9_MAX,
NumberValue(numeric38_9Max),
1.0E29,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
NUMERIC_38_9_MIN,
NumberValue(numeric38_9Min),
-1.0E29,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
// Float/Double MIN_VALUE and MAX_VALUE are all positive, so we manually test their
// negative counterparts
case(
SMALLEST_POSITIVE_FLOAT32,
NumberValue(bigdec(Float.MIN_VALUE.toDouble())),
Float.MIN_VALUE.toDouble()
),
case(
SMALLEST_NEGATIVE_FLOAT32,
NumberValue(bigdec(-Float.MIN_VALUE.toDouble())),
-Float.MIN_VALUE.toDouble()
),
case(
LARGEST_POSITIVE_FLOAT32,
NumberValue(bigdec(Float.MAX_VALUE.toDouble())),
Float.MAX_VALUE.toDouble()
),
case(
LARGEST_NEGATIVE_FLOAT32,
NumberValue(bigdec(-Float.MAX_VALUE.toDouble())),
-Float.MAX_VALUE.toDouble()
),
case(
SMALLEST_POSITIVE_FLOAT64,
NumberValue(bigdec(Double.MIN_VALUE)),
Double.MIN_VALUE
),
case(
SMALLEST_NEGATIVE_FLOAT64,
NumberValue(bigdec(-Double.MIN_VALUE)),
-Double.MIN_VALUE
),
case(LARGEST_POSITIVE_FLOAT64, NumberValue(bigdec(Double.MAX_VALUE)), Double.MAX_VALUE),
case(
LARGEST_NEGATIVE_FLOAT64,
NumberValue(bigdec(-Double.MAX_VALUE)),
-Double.MAX_VALUE
),
// These values are out of bounds, so we null them
case(
SLIGHTLY_ABOVE_LARGEST_POSITIVE_FLOAT64,
NumberValue(bigdec(Double.MAX_VALUE) + bigdec(Double.MIN_VALUE)),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
SLIGHTLY_BELOW_LARGEST_NEGATIVE_FLOAT64,
NumberValue(bigdec(-Double.MAX_VALUE) - bigdec(Double.MIN_VALUE)),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
)
val numeric38_9 =
listOf(
case(NULL, NullValue, null),
case(ZERO, NumberValue(bigdec(0)), bigdec(0.0)),
case(ONE, NumberValue(bigdec(1)), bigdec(1.0)),
case(NEGATIVE_ONE, NumberValue(bigdec(-1)), bigdec(-1.0)),
// This value isn't exactly representable as a float64
// (the exact value is `123.400000000000005684341886080801486968994140625`)
// but it's perfectly fine as a numeric(38, 9)
case(
ONE_HUNDRED_TWENTY_THREE_POINT_FOUR,
NumberValue(bigdec("123.4")),
bigdec("123.4")
),
case(
NEGATIVE_ONE_HUNDRED_TWENTY_THREE_POINT_FOUR,
NumberValue(bigdec("-123.4")),
bigdec("-123.4")
),
// These values have too much precision for a numeric(38, 9), so we round them
case(
POSITIVE_HIGH_PRECISION_FLOAT,
NumberValue(bigdec("1234567890.1234567890123456789")),
bigdec("1234567890.123456789"),
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
NEGATIVE_HIGH_PRECISION_FLOAT,
NumberValue(bigdec("-1234567890.1234567890123456789")),
bigdec("-1234567890.123456789"),
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
SMALLEST_POSITIVE_FLOAT32,
NumberValue(bigdec(Float.MIN_VALUE.toDouble())),
bigdec(0),
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
SMALLEST_NEGATIVE_FLOAT32,
NumberValue(bigdec(-Float.MIN_VALUE.toDouble())),
bigdec(0),
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
SMALLEST_POSITIVE_FLOAT64,
NumberValue(bigdec(Double.MIN_VALUE)),
bigdec(0),
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
SMALLEST_NEGATIVE_FLOAT64,
NumberValue(bigdec(-Double.MIN_VALUE)),
bigdec(0),
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
// numeric bounds are perfectly fine
case(NUMERIC_38_9_MAX, NumberValue(numeric38_9Max), numeric38_9Max),
case(NUMERIC_38_9_MIN, NumberValue(numeric38_9Min), numeric38_9Min),
// These values are out of bounds, so we null them
case(
LARGEST_POSITIVE_FLOAT32,
NumberValue(bigdec(Float.MAX_VALUE.toDouble())),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
LARGEST_NEGATIVE_FLOAT32,
NumberValue(bigdec(-Float.MAX_VALUE.toDouble())),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
LARGEST_POSITIVE_FLOAT64,
NumberValue(bigdec(Double.MAX_VALUE)),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
LARGEST_NEGATIVE_FLOAT64,
NumberValue(bigdec(-Double.MAX_VALUE)),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
SLIGHTLY_ABOVE_LARGEST_POSITIVE_FLOAT64,
NumberValue(bigdec(Double.MAX_VALUE) + bigdec(Double.MIN_VALUE)),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
SLIGHTLY_BELOW_LARGEST_NEGATIVE_FLOAT64,
NumberValue(bigdec(-Double.MAX_VALUE) - bigdec(Double.MIN_VALUE)),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
)
.map { it.copy(outputValue = (it.outputValue as BigDecimal?)?.setScale(9)) }
@JvmStatic fun float64() = float64.toArgs()
@JvmStatic fun numeric38_9() = numeric38_9.toArgs()
}
const val SIMPLE_TIMESTAMP = "simple timestamp"
const val UNIX_EPOCH = "unix epoch"
const val MINIMUM_TIMESTAMP = "minimum timestamp"
const val MAXIMUM_TIMESTAMP = "maximum timestamp"
const val OUT_OF_RANGE_TIMESTAMP = "out of range timestamp"
const val HIGH_PRECISION_TIMESTAMP = "high-precision timestamp"
object DataCoercionTimestampTzFixtures {
/**
* Many warehouses support timestamps between years 0001 - 9999.
*
* Depending on the exact warehouse, you may need to tweak the precision on some values. For
* example, Snowflake supports nanosecond-precision timestamps (9 decimal places), but BigQuery
* only supports microsecond precision (6 decimal places). BigQuery would probably do something
* like:
* ```kotlin
* DataCoercionTimestampTzFixtures.commonWarehouse
* .map {
* when (it.name) {
* "maximum AD timestamp" -> it.copy(
* inputValue = TimestampWithTimezoneValue("9999-12-31T23:59:59.999999Z"),
* outputValue = OffsetDateTime.parse("9999-12-31T23:59:59.999999Z"),
* changeReason = Reason.DESTINATION_FIELD_SIZE_LIMITATION,
* )
* "high-precision timestamp" -> it.copy(
* outputValue = OffsetDateTime.parse("2025-01-23T01:01:00.123456Z"),
* changeReason = Reason.DESTINATION_FIELD_SIZE_LIMITATION,
* )
* else -> it
* }
* }
* ```
*/
val commonWarehouse =
listOf(
case(NULL, NullValue, null),
case(
SIMPLE_TIMESTAMP,
TimestampWithTimezoneValue("2025-01-23T12:34:56.789Z"),
"2025-01-23T12:34:56.789Z",
),
case(
UNIX_EPOCH,
TimestampWithTimezoneValue("1970-01-01T00:00:00Z"),
"1970-01-01T00:00:00Z",
),
case(
MINIMUM_TIMESTAMP,
TimestampWithTimezoneValue("0001-01-01T00:00:00Z"),
"0001-01-01T00:00:00Z",
),
case(
MAXIMUM_TIMESTAMP,
TimestampWithTimezoneValue("9999-12-31T23:59:59.999999999Z"),
"9999-12-31T23:59:59.999999999Z",
),
case(
OUT_OF_RANGE_TIMESTAMP,
TimestampWithTimezoneValue(odt("10000-01-01T00:00Z")),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION,
),
case(
HIGH_PRECISION_TIMESTAMP,
TimestampWithTimezoneValue("2025-01-23T01:01:00.123456789Z"),
"2025-01-23T01:01:00.123456789Z",
),
)
@JvmStatic fun commonWarehouse() = commonWarehouse.toArgs()
}
object DataCoercionTimestampNtzFixtures {
/** See [DataCoercionTimestampTzFixtures.commonWarehouse] for explanation */
val commonWarehouse =
listOf(
case(NULL, NullValue, null),
case(
SIMPLE_TIMESTAMP,
TimestampWithoutTimezoneValue("2025-01-23T12:34:56.789"),
"2025-01-23T12:34:56.789",
),
case(
UNIX_EPOCH,
TimestampWithoutTimezoneValue("1970-01-01T00:00:00"),
"1970-01-01T00:00:00",
),
case(
MINIMUM_TIMESTAMP,
TimestampWithoutTimezoneValue("0001-01-01T00:00:00"),
"0001-01-01T00:00:00",
),
case(
MAXIMUM_TIMESTAMP,
TimestampWithoutTimezoneValue("9999-12-31T23:59:59.999999999"),
"9999-12-31T23:59:59.999999999",
),
case(
OUT_OF_RANGE_TIMESTAMP,
TimestampWithoutTimezoneValue(ldt("10000-01-01T00:00")),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION,
),
case(
HIGH_PRECISION_TIMESTAMP,
TimestampWithoutTimezoneValue("2025-01-23T01:01:00.123456789"),
"2025-01-23T01:01:00.123456789",
),
)
@JvmStatic fun commonWarehouse() = commonWarehouse.toArgs()
}
const val MIDNIGHT = "midnight"
const val MAX_TIME = "max time"
const val HIGH_NOON = "high noon"
object DataCoercionTimeTzFixtures {
val timetz =
listOf(
case(NULL, NullValue, null),
case(MIDNIGHT, TimeWithTimezoneValue("00:00Z"), "00:00Z"),
case(MAX_TIME, TimeWithTimezoneValue("23:59:59.999999999Z"), "23:59:59.999999999Z"),
case(HIGH_NOON, TimeWithTimezoneValue("12:00Z"), "12:00Z"),
)
@JvmStatic fun timetz() = timetz.toArgs()
}
object DataCoercionTimeNtzFixtures {
val timentz =
listOf(
case(NULL, NullValue, null),
case(MIDNIGHT, TimeWithoutTimezoneValue("00:00"), "00:00"),
case(MAX_TIME, TimeWithoutTimezoneValue("23:59:59.999999999"), "23:59:59.999999999"),
case(HIGH_NOON, TimeWithoutTimezoneValue("12:00"), "12:00"),
)
@JvmStatic fun timentz() = timentz.toArgs()
}
object DataCoercionDateFixtures {
val commonWarehouse =
listOf(
case(NULL, NullValue, null),
case(
SIMPLE_TIMESTAMP,
DateValue("2025-01-23"),
"2025-01-23",
),
case(
UNIX_EPOCH,
DateValue("1970-01-01"),
"1970-01-01",
),
case(
MINIMUM_TIMESTAMP,
DateValue("0001-01-01"),
"0001-01-01",
),
case(
MAXIMUM_TIMESTAMP,
DateValue("9999-12-31"),
"9999-12-31",
),
case(
OUT_OF_RANGE_TIMESTAMP,
DateValue(date("10000-01-01")),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION,
),
)
@JvmStatic fun commonWarehouse() = commonWarehouse.toArgs()
}
object DataCoercionStringFixtures {
const val EMPTY_STRING = "empty string"
const val SHORT_STRING = "short string"
const val LONG_STRING = "long string"
const val SPECIAL_CHARS_STRING = "special chars string"
val strings =
listOf(
case(NULL, NullValue, null),
case(EMPTY_STRING, StringValue(""), ""),
case(SHORT_STRING, StringValue("foo"), "foo"),
// Implementers may override this to test their destination-specific limits.
// The default value is 16MB + 1 byte (slightly longer than Snowflake's varchar limit).
case(
LONG_STRING,
StringValue("a".repeat(16777216 + 1)),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
SPECIAL_CHARS_STRING,
StringValue("`~!@#$%^&*()-=_+[]\\{}|o'O\",./<>?)Δ⅀↑∀"),
"`~!@#$%^&*()-=_+[]\\{}|o'O\",./<>?)Δ⅀↑∀"
),
)
@JvmStatic fun strings() = strings.toArgs()
}
object DataCoercionObjectFixtures {
const val EMPTY_OBJECT = "empty object"
const val NORMAL_OBJECT = "normal object"
val objects =
listOf(
case(NULL, NullValue, null),
case(EMPTY_OBJECT, ObjectValue(linkedMapOf()), emptyMap<String, Any?>()),
case(
NORMAL_OBJECT,
ObjectValue(linkedMapOf("foo" to StringValue("bar"))),
mapOf("foo" to "bar")
),
)
val stringifiedObjects =
objects.map { fixture ->
fixture.copy(outputValue = fixture.outputValue?.serializeToString())
}
@JvmStatic fun objects() = objects.toArgs()
@JvmStatic fun stringifiedObjects() = stringifiedObjects.toArgs()
}
object DataCoercionArrayFixtures {
const val EMPTY_ARRAY = "empty array"
const val NORMAL_ARRAY = "normal array"
val arrays =
listOf(
case(NULL, NullValue, null),
case(EMPTY_ARRAY, ArrayValue(emptyList()), emptyList<Any?>()),
case(NORMAL_ARRAY, ArrayValue(listOf(StringValue("foo"))), listOf("foo")),
)
val stringifiedArrays =
arrays.map { fixture ->
fixture.copy(outputValue = fixture.outputValue?.serializeToString())
}
@JvmStatic fun arrays() = arrays.toArgs()
@JvmStatic fun stringifiedArrays() = stringifiedArrays.toArgs()
}
const val UNION_INT_VALUE = "int value"
const val UNION_OBJ_VALUE = "object value"
const val UNION_STR_VALUE = "string value"
object DataCoercionUnionFixtures {
val unions =
listOf(
case(NULL, NullValue, null),
case(UNION_INT_VALUE, IntegerValue(42), 42L),
case(UNION_STR_VALUE, StringValue("foo"), "foo"),
case(
UNION_OBJ_VALUE,
ObjectValue(linkedMapOf("foo" to StringValue("bar"))),
mapOf("foo" to "bar")
),
)
val stringifiedUnions =
unions.map { fixture ->
fixture.copy(outputValue = fixture.outputValue?.serializeToString())
}
@JvmStatic fun unions() = unions.toArgs()
@JvmStatic fun stringifiedUnions() = stringifiedUnions.toArgs()
}
object DataCoercionLegacyUnionFixtures {
val unions =
listOf(
case(NULL, NullValue, null),
// Legacy union of int x object will select object, and you can't write an int to an
// object column.
// So we should null it out.
case(UNION_INT_VALUE, IntegerValue(42), null, Reason.DESTINATION_TYPECAST_ERROR),
// Strings, however, are retained in this fixture; override this case if your destination
// also rejects strings in object columns.
case(UNION_STR_VALUE, StringValue("foo"), "foo"),
// But objects can be written as objects, so retain this value.
case(
UNION_OBJ_VALUE,
ObjectValue(linkedMapOf("foo" to StringValue("bar"))),
mapOf("foo" to "bar")
),
)
val stringifiedUnions =
DataCoercionUnionFixtures.unions.map { fixture ->
fixture.copy(outputValue = fixture.outputValue?.serializeToString())
}
@JvmStatic fun unions() = unions.toArgs()
@JvmStatic fun stringifiedUnions() = DataCoercionUnionFixtures.stringifiedUnions.toArgs()
}
// This is pretty much identical to UnionFixtures, but separating them in case we need to add
// different test cases for either of them.
object DataCoercionUnknownFixtures {
const val INT_VALUE = "integer value"
const val STR_VALUE = "string value"
const val OBJ_VALUE = "object value"
val unknowns =
listOf(
case(NULL, NullValue, null),
case(INT_VALUE, IntegerValue(42), 42L),
case(STR_VALUE, StringValue("foo"), "foo"),
case(
OBJ_VALUE,
ObjectValue(linkedMapOf("foo" to StringValue("bar"))),
mapOf("foo" to "bar")
),
)
val stringifiedUnknowns =
unknowns.map { fixture ->
fixture.copy(outputValue = fixture.outputValue?.serializeToString())
}
@JvmStatic fun unknowns() = unknowns.toArgs()
@JvmStatic fun stringifiedUnknowns() = stringifiedUnknowns.toArgs()
}
fun List<DataCoercionTestCase>.toArgs(): List<Arguments> =
this.map { Arguments.argumentSet(it.name, it.inputValue, it.outputValue, it.changeReason) }
.toList()
/**
* Utility method to use the BigDecimal constructor (supports exponential notation like `1e38`) to
* construct a BigInteger.
*/
fun bigint(str: String): BigInteger = BigDecimal(str).toBigIntegerExact()
/** Shorthand utility method to construct a bigint from a long */
fun bigint(long: Long): BigInteger = BigInteger.valueOf(long)
fun bigdec(str: String): BigDecimal = BigDecimal(str)
fun bigdec(double: Double): BigDecimal = BigDecimal.valueOf(double)
fun bigdec(int: Int): BigDecimal = BigDecimal.valueOf(int.toDouble())
fun odt(str: String): OffsetDateTime = OffsetDateTime.parse(str, dateTimeFormatter)
fun ldt(str: String): LocalDateTime = LocalDateTime.parse(str, dateTimeFormatter)
fun date(str: String): LocalDate = LocalDate.parse(str, dateFormatter)
// The default java.time.*.parse() behavior only accepts up to 4-digit years.
// Build a custom formatter to handle larger years.
val dateFormatter =
DateTimeFormatterBuilder()
// java.time.* supports up to 9-digit years
.appendValue(ChronoField.YEAR, 1, 9, SignStyle.NORMAL)
.appendLiteral('-')
.appendValue(ChronoField.MONTH_OF_YEAR)
.appendLiteral('-')
.appendValue(ChronoField.DAY_OF_MONTH)
.toFormatter()
val dateTimeFormatter =
DateTimeFormatterBuilder()
.append(dateFormatter)
.appendLiteral('T')
// Accepts strings with/without an offset, so we can use this formatter
// for both timestamp with and without timezone
.append(DateTimeFormatter.ISO_TIME)
.toFormatter()
/**
* Represents a single data coercion test case. You probably want to use [case] as a shorthand
* constructor.
*
* @param name A short human-readable name for the test. Primarily useful for tests where
* [inputValue] is either very long, or otherwise hard to read.
* @param inputValue The value to pass into [ValueCoercer.validate]
* @param outputValue The value that we expect to read back from the destination. Should be
* basically equivalent to the output of [ValueCoercer.validate]
* @param changeReason If `validate` returns Truncate/Nullify, the reason for that
* truncation/nullification. If `validate` returns Valid, this should be null.
*/
data class DataCoercionTestCase(
val name: String,
val inputValue: AirbyteValue,
val outputValue: Any?,
val changeReason: Reason? = null,
)
fun case(
name: String,
inputValue: AirbyteValue,
outputValue: Any?,
changeReason: Reason? = null,
) = DataCoercionTestCase(name, inputValue, outputValue, changeReason)
const val NULL = "null"

View File

@@ -0,0 +1,369 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.component
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.ArrayType
import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
import io.airbyte.cdk.load.data.BooleanType
import io.airbyte.cdk.load.data.BooleanValue
import io.airbyte.cdk.load.data.DateType
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerType
import io.airbyte.cdk.load.data.NumberType
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.TimeTypeWithTimezone
import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
import io.airbyte.cdk.load.data.UnionType
import io.airbyte.cdk.load.data.UnknownType
import io.airbyte.cdk.load.dataflow.transform.ValueCoercer
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.util.Jsons
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Reason
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import kotlinx.coroutines.test.runTest
/**
* The tests in this class are designed to reference the parameters defined in
* `DataCoercionFixtures.kt`. For example, you might annotate [`handle integer values`] with
* `@MethodSource("io.airbyte.cdk.load.component.DataCoercionIntegerFixtures#int64")`. See each
* fixture class for explanations of what behavior they are exercising.
*
* Note that this class _only_ exercises [ValueCoercer.validate]. You should write separate unit
* tests for [ValueCoercer.map]. For now, the `map` function is primarily intended for transforming
* `UnionType` fields into other types (typically `StringType`), at which point your `validate`
* implementation should be able to handle any StringValue (regardless of whether it was originally
* a StringType or UnionType).
*/
@MicronautTest(environments = ["component"], resolveParameters = false)
interface DataCoercionSuite {
val coercer: ValueCoercer
val airbyteMetaColumnMapping: Map<String, String>
get() = Meta.COLUMN_NAMES.associateWith { it }
val columnNameMapping: ColumnNameMapping
get() = ColumnNameMapping(mapOf("test" to "test"))
val opsClient: TableOperationsClient
val testClient: TestTableOperationsClient
val schemaFactory: TableSchemaFactory
val harness: TableOperationsTestHarness
get() =
TableOperationsTestHarness(
opsClient,
testClient,
schemaFactory,
airbyteMetaColumnMapping
)
/** Fixtures are defined in [DataCoercionIntegerFixtures]. */
fun `handle integer values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(IntegerType, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionNumberFixtures]. */
fun `handle number values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(NumberType, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionTimestampTzFixtures]. */
fun `handle timestamptz values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(TimestampTypeWithTimezone, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionTimestampNtzFixtures]. */
fun `handle timestampntz values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(TimestampTypeWithoutTimezone, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionTimeTzFixtures]. */
fun `handle timetz values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(TimeTypeWithTimezone, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionTimeNtzFixtures]. */
fun `handle timentz values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(TimeTypeWithoutTimezone, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionDateFixtures]. */
fun `handle date values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(DateType, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** No fixtures, hardcoded to just write `true` */
fun `handle bool values`(expectedValue: Any?) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(BooleanType, nullable = true),
// Just test on `true` and assume `false` also works
BooleanValue(true),
expectedValue,
// If your destination is nulling/truncating booleans... that's almost definitely a bug
expectedChangeReason = null,
)
}
/** Fixtures are defined in [DataCoercionStringFixtures]. */
fun `handle string values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(StringType, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionObjectFixtures]. */
fun `handle object values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(
ObjectType(linkedMapOf("foo" to FieldType(StringType, true))),
nullable = true
),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionObjectFixtures]. */
fun `handle empty object values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(ObjectTypeWithEmptySchema, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionObjectFixtures]. */
fun `handle schemaless object values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(ObjectTypeWithoutSchema, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionArrayFixtures]. */
fun `handle array values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(ArrayType(FieldType(StringType, true)), nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionArrayFixtures]. */
fun `handle schemaless array values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(ArrayTypeWithoutSchema, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/**
* All destinations should implement this test, even if your destination supports legacy unions.
*
* Fixtures are defined in [DataCoercionUnionFixtures].
*/
fun `handle union values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(
UnionType(
setOf(
ObjectType(linkedMapOf("foo" to FieldType(StringType, true))),
IntegerType,
StringType,
),
isLegacyUnion = false
),
nullable = true
),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/**
* Only legacy destinations that are maintaining "legacy" union behavior should implement this
* test. If you're not sure, check whether your `application-connector.yaml` includes an
* `airbyte.destination.core.types.unions: LEGACY` property.
*
* Fixtures are defined in [DataCoercionLegacyUnionFixtures].
*/
fun `handle legacy union values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(
UnionType(
setOf(
ObjectType(linkedMapOf("foo" to FieldType(StringType, true))),
IntegerType,
StringType,
),
isLegacyUnion = true
),
nullable = true
),
inputValue,
expectedValue,
expectedChangeReason,
)
}
fun `handle unknown values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(UnknownType(Jsons.readTree(("""{"type": "potato"}"""))), nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
}
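As a minimal sketch of the wiring described in the class comment, a destination's component test might implement the suite like this; the class name, the constructor injection of the CDK beans, and the choice of the `int64` fixture are assumptions about a hypothetical destination:

```kotlin
import io.airbyte.cdk.load.component.DataCoercionSuite
import io.airbyte.cdk.load.component.TableOperationsClient
import io.airbyte.cdk.load.component.TestTableOperationsClient
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.dataflow.transform.ValueCoercer
import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Reason
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.MethodSource

@MicronautTest(environments = ["component"], resolveParameters = false)
class ExampleDestinationDataCoercionTest(
    override val coercer: ValueCoercer,
    override val opsClient: TableOperationsClient,
    override val testClient: TestTableOperationsClient,
    override val schemaFactory: TableSchemaFactory,
) : DataCoercionSuite {
    // Pick the fixture that matches this destination's integer column type (int64 here),
    // then forward to the suite's default implementation.
    @ParameterizedTest
    @MethodSource("io.airbyte.cdk.load.component.DataCoercionIntegerFixtures#int64")
    override fun `handle integer values`(
        inputValue: AirbyteValue,
        expectedValue: Any?,
        expectedChangeReason: Reason?
    ) =
        super<DataCoercionSuite>.`handle integer values`(
            inputValue,
            expectedValue,
            expectedChangeReason
        )
}
```

The remaining suite methods are overridden the same way, each pointing at the fixture list that matches the destination's column types.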

View File

@@ -23,6 +23,7 @@ import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
import io.airbyte.cdk.load.data.TimestampWithTimezoneValue
import io.airbyte.cdk.load.data.UnionType
import io.airbyte.cdk.load.data.UnknownType
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
@@ -84,6 +85,18 @@ object TableOperationsFixtures {
"array" to FieldType(ArrayType(FieldType(StringType, true)), true),
"object" to
FieldType(ObjectType(linkedMapOf("key" to FieldType(StringType, true))), true),
"union" to
FieldType(
UnionType(setOf(StringType, IntegerType), isLegacyUnion = false),
true
),
// Most destinations just ignore the isLegacyUnion flag, which is totally fine.
// This is here for the small set of connectors that respect it.
"legacy_union" to
FieldType(
UnionType(setOf(StringType, IntegerType), isLegacyUnion = true),
true
),
"unknown" to FieldType(UnknownType(Jsons.readTree("""{"type": "potato"}""")), true),
),
)
@@ -101,6 +114,8 @@ object TableOperationsFixtures {
"time_ntz" to "time_ntz",
"array" to "array",
"object" to "object",
"union" to "union",
"legacy_union" to "legacy_union",
"unknown" to "unknown",
)
)
@@ -714,6 +729,11 @@ object TableOperationsFixtures {
return map { record -> record.mapKeys { (k, _) -> totalMapping.invert()[k] ?: k } }
}
fun <V> List<Map<String, V>>.removeAirbyteColumns(
airbyteMetaColumnMapping: Map<String, String>
): List<Map<String, V>> =
this.map { rec -> rec.filter { !airbyteMetaColumnMapping.containsValue(it.key) } }
fun <V> List<Map<String, V>>.removeNulls() =
this.map { record -> record.filterValues { it != null } }

View File

@@ -58,7 +58,8 @@ interface TableOperationsSuite {
get() = Meta.COLUMN_NAMES.associateWith { it }
private val harness: TableOperationsTestHarness
get() = TableOperationsTestHarness(client, testClient, airbyteMetaColumnMapping)
get() =
TableOperationsTestHarness(client, testClient, schemaFactory, airbyteMetaColumnMapping)
/** Tests basic database connectivity by pinging the database. */
fun `connect to database`() = runTest { assertDoesNotThrow { testClient.ping() } }
@@ -606,7 +607,7 @@ interface TableOperationsSuite {
val targetTableSchema =
schemaFactory.make(
targetTable,
Fixtures.TEST_INTEGER_SCHEMA.properties,
Fixtures.ID_TEST_WITH_CDC_SCHEMA.properties,
Dedupe(
primaryKey = listOf(listOf(Fixtures.ID_FIELD)),
cursor = listOf(Fixtures.TEST_FIELD),

View File

@@ -4,11 +4,24 @@
package io.airbyte.cdk.load.component
import io.airbyte.cdk.load.command.Append
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.component.TableOperationsFixtures.inputRecord
import io.airbyte.cdk.load.component.TableOperationsFixtures.insertRecords
import io.airbyte.cdk.load.component.TableOperationsFixtures.removeAirbyteColumns
import io.airbyte.cdk.load.component.TableOperationsFixtures.removeNulls
import io.airbyte.cdk.load.component.TableOperationsFixtures.reverseColumnNameMapping
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.EnrichedAirbyteValue
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.dataflow.transform.ValidationResult
import io.airbyte.cdk.load.dataflow.transform.ValueCoercer
import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Reason
import io.github.oshai.kotlinlogging.KotlinLogging
import org.junit.jupiter.api.Assertions.assertEquals
@@ -21,6 +34,7 @@ private val log = KotlinLogging.logger {}
class TableOperationsTestHarness(
private val client: TableOperationsClient,
private val testClient: TestTableOperationsClient,
private val schemaFactory: TableSchemaFactory,
private val airbyteMetaColumnMapping: Map<String, String>,
) {
@@ -100,8 +114,77 @@ class TableOperationsTestHarness(
/** Reads records from a table, filtering out Meta columns. */
suspend fun readTableWithoutMetaColumns(tableName: TableName): List<Map<String, Any>> {
val tableRead = testClient.readTable(tableName)
return tableRead.map { rec ->
rec.filter { !airbyteMetaColumnMapping.containsValue(it.key) }
return tableRead.removeAirbyteColumns(airbyteMetaColumnMapping)
}
/** Apply the coercer to a value and verify that we can write the coerced value correctly */
suspend fun testValueCoercion(
coercer: ValueCoercer,
columnNameMapping: ColumnNameMapping,
fieldType: FieldType,
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?,
) {
val testNamespace = TableOperationsFixtures.generateTestNamespace("test")
val tableName =
TableOperationsFixtures.generateTestTableName("table-test-table", testNamespace)
val schema = ObjectType(linkedMapOf("test" to fieldType))
val tableSchema = schemaFactory.make(tableName, schema.properties, Append)
val stream =
TableOperationsFixtures.createStream(
namespace = tableName.namespace,
name = tableName.name,
tableSchema = tableSchema,
)
val inputValueAsEnrichedAirbyteValue =
EnrichedAirbyteValue(
inputValue,
fieldType.type,
"test",
airbyteMetaField = null,
)
val validatedValue = coercer.validate(inputValueAsEnrichedAirbyteValue)
val valueToInsert: AirbyteValue
val changeReason: Reason?
when (validatedValue) {
is ValidationResult.ShouldNullify -> {
valueToInsert = NullValue
changeReason = validatedValue.reason
}
is ValidationResult.ShouldTruncate -> {
valueToInsert = validatedValue.truncatedValue
changeReason = validatedValue.reason
}
ValidationResult.Valid -> {
valueToInsert = inputValue
changeReason = null
}
}
client.createNamespace(testNamespace)
client.createTable(stream, tableName, columnNameMapping, replace = false)
testClient.insertRecords(
tableName,
columnNameMapping,
inputRecord("test" to valueToInsert),
)
val actualRecords =
testClient
.readTable(tableName)
.removeAirbyteColumns(airbyteMetaColumnMapping)
.reverseColumnNameMapping(columnNameMapping, airbyteMetaColumnMapping)
.removeNulls()
val actualValue = actualRecords.first()["test"]
assertEquals(
expectedValue,
actualValue,
"For input $inputValue, expected ${expectedValue.simpleClassName()}; actual value was ${actualValue.simpleClassName()}. Coercer output was $validatedValue.",
)
assertEquals(expectedChangeReason, changeReason)
}
}
fun Any?.simpleClassName() = this?.let { it::class.simpleName } ?: "null"
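A minimal usage sketch of the new `testValueCoercion` helper, not part of this diff; the field name, input, and expected values are hypothetical, and the call shape follows the signature shown above.

```kotlin
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.dataflow.transform.ValueCoercer
import io.airbyte.cdk.load.table.ColumnNameMapping

// Sketch: exercise the harness with a single valid string value.
suspend fun exampleStringCoercion(harness: TableOperationsTestHarness, coercer: ValueCoercer) {
    harness.testValueCoercion(
        coercer = coercer,
        columnNameMapping = ColumnNameMapping(mapOf("test" to "test")),
        fieldType = FieldType(StringType, nullable = true),
        inputValue = StringValue("hello"),
        expectedValue = "hello", // value expected to be read back unchanged
        expectedChangeReason = null, // ValidationResult.Valid implies no change reason
    )
}
```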

View File

@@ -44,7 +44,13 @@ interface TableSchemaEvolutionSuite {
val schemaFactory: TableSchemaFactory
private val harness: TableOperationsTestHarness
get() = TableOperationsTestHarness(opsClient, testClient, airbyteMetaColumnMapping)
get() =
TableOperationsTestHarness(
opsClient,
testClient,
schemaFactory,
airbyteMetaColumnMapping
)
/**
* Test that the connector can correctly discover all of its own data types. This test creates a

View File

@@ -1 +1 @@
version=0.1.87
version=0.1.91

View File

@@ -10,5 +10,6 @@ CONNECTOR_PATH_PREFIXES = {
"airbyte-integrations/connectors",
"docs/integrations/sources",
"docs/integrations/destinations",
"docs/ai-agents/connectors",
}
MERGE_METHOD = "squash"

View File

@@ -75,7 +75,7 @@ This will copy the specified connector version to your development bucket. This
_💡 Note: A prerequisite is that you have [gsutil](https://cloud.google.com/storage/docs/gsutil) installed and have run `gsutil auth login`_
```bash
TARGET_BUCKET=<YOUR-DEV_BUCKET> CONNECTOR="airbyte/source-stripe" VERSION="3.17.0-dev.ea013c8741" poetry run poe copy-connector-from-prod
TARGET_BUCKET=<YOUR-DEV_BUCKET> CONNECTOR="airbyte/source-stripe" VERSION="3.17.0-preview.ea013c8" poetry run poe copy-connector-from-prod
```
### Promote Connector Version to Latest
@@ -87,5 +87,5 @@ _💡 Note: A prerequisite is you have [gsutil](https://cloud.google.com/storage
_⚠️ Warning: It's important to know that this will remove ANY existing files in the latest folder that are not in the versioned folder, as it calls `gsutil rsync` with `-d` enabled._
```bash
TARGET_BUCKET=<YOUR-DEV_BUCKET> CONNECTOR="airbyte/source-stripe" VERSION="3.17.0-dev.ea013c8741" poetry run poe promote-connector-to-latest
TARGET_BUCKET=<YOUR-DEV_BUCKET> CONNECTOR="airbyte/source-stripe" VERSION="3.17.0-preview.ea013c8" poetry run poe promote-connector-to-latest
```

View File

@@ -28,8 +28,8 @@ def get_docker_hub_auth_token() -> str:
def get_docker_hub_headers() -> Dict | None:
if "DOCKER_HUB_USERNAME" not in os.environ or "DOCKER_HUB_PASSWORD" not in os.environ:
# If the Docker Hub credentials are not provided, we can only anonymously call the Docker Hub API.
if not os.environ.get("DOCKER_HUB_USERNAME") or not os.environ.get("DOCKER_HUB_PASSWORD"):
# If the Docker Hub credentials are not provided (or are empty), we can only call the Docker Hub API anonymously.
# This will only work for public images and lead to a lower rate limit.
return {}
else:

View File

@@ -434,7 +434,7 @@ def generate_and_persist_registry_entry(
bucket_name (str): The name of the GCS bucket.
repo_metadata_file_path (pathlib.Path): The path to the spec file.
registry_type (str): The registry type.
docker_image_tag (str): The docker image tag associated with this release. Typically a semver string (e.g. '1.2.3'), possibly with a suffix (e.g. '1.2.3-dev.abcde12345')
docker_image_tag (str): The docker image tag associated with this release. Typically a semver string (e.g. '1.2.3'), possibly with a suffix (e.g. '1.2.3-preview.abcde12')
is_prerelease (bool): Whether this is a prerelease, or a main release.
"""
# Read the repo metadata dict to bootstrap ourselves. We need the docker repository,
@@ -444,7 +444,7 @@ def generate_and_persist_registry_entry(
try:
# Now that we have the docker repo, read the appropriate versioned metadata from GCS.
# This metadata will differ in a few fields (e.g. in prerelease mode, dockerImageTag will contain the actual prerelease tag `1.2.3-dev.abcde12345`),
# This metadata will differ in a few fields (e.g. in prerelease mode, dockerImageTag will contain the actual prerelease tag `1.2.3-preview.abcde12`),
# so we'll treat this as the source of truth (ish. See below for how we handle the registryOverrides field.)
gcs_client = get_gcs_storage_client(gcs_creds=os.environ.get("GCS_CREDENTIALS"))
bucket = gcs_client.bucket(bucket_name)
@@ -533,7 +533,9 @@ def generate_and_persist_registry_entry(
# For latest versions that are disabled, delete any existing registry entry to remove it from the registry
if (
"-rc" not in metadata_dict["data"]["dockerImageTag"] and "-dev" not in metadata_dict["data"]["dockerImageTag"]
"-rc" not in metadata_dict["data"]["dockerImageTag"]
and "-dev" not in metadata_dict["data"]["dockerImageTag"]
and "-preview" not in metadata_dict["data"]["dockerImageTag"]
) and not metadata_dict["data"]["registryOverrides"][registry_type]["enabled"]:
logger.info(
f"{registry_type} is not enabled: deleting existing {registry_type} registry entry for {metadata_dict['data']['dockerRepository']} at latest path."

View File

@@ -5,7 +5,7 @@ data:
connectorType: source
dockerRepository: airbyte/image-exists-1
githubIssueLabel: source-alloydb-strict-encrypt
dockerImageTag: 2.0.0-dev.cf3628ccf3
dockerImageTag: 2.0.0-preview.cf3628c
documentationUrl: https://docs.airbyte.com/integrations/sources/existingsource
connectorSubtype: database
releaseStage: generally_available

View File

@@ -231,7 +231,7 @@ def test_upload_prerelease(mocker, valid_metadata_yaml_files, tmp_path):
mocker.patch.object(commands.click, "secho")
mocker.patch.object(commands, "upload_metadata_to_gcs")
prerelease_tag = "0.3.0-dev.6d33165120"
prerelease_tag = "0.3.0-preview.6d33165"
bucket = "my-bucket"
metadata_file_path = valid_metadata_yaml_files[0]
validator_opts = ValidatorOptions(docs_path=str(tmp_path), prerelease_tag=prerelease_tag)

View File

@@ -582,7 +582,7 @@ def test_upload_metadata_to_gcs_invalid_docker_images(mocker, invalid_metadata_u
def test_upload_metadata_to_gcs_with_prerelease(mocker, valid_metadata_upload_files, tmp_path):
mocker.spy(gcs_upload, "_file_upload")
mocker.spy(gcs_upload, "upload_file_if_changed")
prerelease_image_tag = "1.5.6-dev.f80318f754"
prerelease_image_tag = "1.5.6-preview.f80318f"
for valid_metadata_upload_file in valid_metadata_upload_files:
tmp_metadata_file_path = tmp_path / "metadata.yaml"
@@ -701,7 +701,7 @@ def test_upload_metadata_to_gcs_release_candidate(mocker, get_fixture_path, tmp_
)
assert metadata.data.releases.rolloutConfiguration.enableProgressiveRollout
prerelease_tag = "1.5.6-dev.f80318f754" if prerelease else None
prerelease_tag = "1.5.6-preview.f80318f" if prerelease else None
upload_info = gcs_upload.upload_metadata_to_gcs(
"my_bucket",

View File

@@ -110,14 +110,14 @@ class PublishConnectorContext(ConnectorContext):
@property
def pre_release_suffix(self) -> str:
return self.git_revision[:10]
return self.git_revision[:7]
@property
def docker_image_tag(self) -> str:
# get the docker image tag from the parent class
metadata_tag = super().docker_image_tag
if self.pre_release:
return f"{metadata_tag}-dev.{self.pre_release_suffix}"
return f"{metadata_tag}-preview.{self.pre_release_suffix}"
else:
return metadata_tag

View File

@@ -25,7 +25,7 @@ from pipelines.helpers.utils import raise_if_not_user
from pipelines.models.steps import STEP_PARAMS, Step, StepResult
# Pin the PyAirbyte version to avoid updates from breaking CI
PYAIRBYTE_VERSION = "0.20.2"
PYAIRBYTE_VERSION = "0.35.1"
class PytestStep(Step, ABC):

View File

@@ -156,7 +156,8 @@ class TestPyAirbyteValidationTests:
result = await PyAirbyteValidation(context_for_valid_connector)._run(mocker.MagicMock())
assert isinstance(result, StepResult)
assert result.status == StepStatus.SUCCESS
assert "Getting `spec` output from connector..." in result.stdout
# Verify the connector name appears in output (stable across PyAirbyte versions)
assert context_for_valid_connector.connector.technical_name in (result.stdout + result.stderr)
async def test__run_validation_skip_unpublished_connector(
self,

View File

@@ -1,2 +1,2 @@
cdkVersion=0.1.86
cdkVersion=0.1.89
JunitMethodExecutionTimeout=10m

View File

@@ -2,7 +2,7 @@ data:
connectorSubtype: database
connectorType: destination
definitionId: ce0d828e-1dc4-496c-b122-2da42e637e48
dockerImageTag: 2.1.16-rc.2
dockerImageTag: 2.1.18
dockerRepository: airbyte/destination-clickhouse
githubIssueLabel: destination-clickhouse
icon: clickhouse.svg
@@ -27,7 +27,7 @@ data:
releaseStage: generally_available
releases:
rolloutConfiguration:
enableProgressiveRollout: true
enableProgressiveRollout: false
breakingChanges:
2.0.0:
message: "This connector has been re-written from scratch. Data will now be typed and stored in final (non-raw) tables. The connector may require changes to its configuration to function properly and downstream pipelines may be affected. Warning: SSH tunneling is in Beta."

View File

@@ -54,8 +54,11 @@ class ClickhouseSqlGenerator {
// Check if cursor column type is valid for ClickHouse ReplacingMergeTree
val cursor = tableSchema.getCursor().firstOrNull()
val cursorType = cursor?.let { finalSchema[it]?.type }
val useCursorAsVersion =
cursorType != null && isValidVersionColumn(cursor, cursorType)
val versionColumn =
if (cursorType?.isValidVersionColumnType() ?: false) {
if (useCursorAsVersion) {
"`$cursor`"
} else {
// Fallback to _airbyte_extracted_at if no cursor is specified or cursor

View File

@@ -4,6 +4,7 @@
package io.airbyte.integrations.destination.clickhouse.client
import io.airbyte.cdk.load.table.CDC_CURSOR_COLUMN
import io.airbyte.integrations.destination.clickhouse.client.ClickhouseSqlTypes.VALID_VERSION_COLUMN_TYPES
object ClickhouseSqlTypes {
@@ -23,4 +24,9 @@ object ClickhouseSqlTypes {
)
}
fun String.isValidVersionColumnType() = VALID_VERSION_COLUMN_TYPES.contains(this)
// Warning: if any munging changes the CDC column name, this will break.
// Currently, that is not the case.
fun isValidVersionColumn(name: String, type: String) =
// CDC cursors cannot be used as a version column since they are null
// during the initial CDC snapshot.
name != CDC_CURSOR_COLUMN && VALID_VERSION_COLUMN_TYPES.contains(type)
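For context, a sketch (not part of this change) of how the version-column check above typically feeds ClickHouse's `ReplacingMergeTree` engine clause; `engineClause` is a hypothetical helper, and `_airbyte_extracted_at` is the fallback named in the generator code.

```kotlin
// Sketch: pick the ReplacingMergeTree version column using the check above.
fun engineClause(cursor: String?, cursorType: String?): String {
    val version =
        if (cursor != null && cursorType != null && isValidVersionColumn(cursor, cursorType)) {
            "`$cursor`" // a non-CDC cursor with a supported type can serve as the version
        } else {
            "`_airbyte_extracted_at`" // fallback when there is no usable cursor
        }
    return "ENGINE = ReplacingMergeTree($version)"
}
```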

View File

@@ -1,62 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.clickhouse.config
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.data.Transformations.Companion.toAlphanumericAndUnderscore
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.ColumnNameGenerator
import io.airbyte.cdk.load.table.FinalTableNameGenerator
import io.airbyte.integrations.destination.clickhouse.spec.ClickhouseConfiguration
import jakarta.inject.Singleton
import java.util.Locale
import java.util.UUID
@Singleton
class ClickhouseFinalTableNameGenerator(private val config: ClickhouseConfiguration) :
FinalTableNameGenerator {
override fun getTableName(streamDescriptor: DestinationStream.Descriptor) =
TableName(
namespace =
(streamDescriptor.namespace ?: config.resolvedDatabase)
.toClickHouseCompatibleName(),
name = streamDescriptor.name.toClickHouseCompatibleName(),
)
}
@Singleton
class ClickhouseColumnNameGenerator : ColumnNameGenerator {
override fun getColumnName(column: String): ColumnNameGenerator.ColumnName {
return ColumnNameGenerator.ColumnName(
column.toClickHouseCompatibleName(),
column.lowercase(Locale.getDefault()).toClickHouseCompatibleName(),
)
}
}
/**
* Transforms a string to be compatible with ClickHouse table and column names.
*
* @return The transformed string suitable for ClickHouse identifiers.
*/
fun String.toClickHouseCompatibleName(): String {
// 1. Replace any character that is not a letter,
// a digit (0-9), or an underscore (_) with a single underscore.
var transformed = toAlphanumericAndUnderscore(this)
// 2. Ensure the identifier does not start with a digit.
// If it starts with a digit, prepend an underscore.
if (transformed.isNotEmpty() && transformed[0].isDigit()) {
transformed = "_$transformed"
}
// 3.Do not allow empty strings.
if (transformed.isEmpty()) {
return "default_name_${UUID.randomUUID()}" // A fallback name if the input results in an
// empty string
}
return transformed
}

View File

@@ -0,0 +1,33 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.clickhouse.schema
import io.airbyte.cdk.load.data.Transformations.Companion.toAlphanumericAndUnderscore
import java.util.UUID
/**
* Transforms a string to be compatible with ClickHouse table and column names.
*
* @return The transformed string suitable for ClickHouse identifiers.
*/
fun String.toClickHouseCompatibleName(): String {
// 1. Replace any character that is not a letter,
// a digit (0-9), or an underscore (_) with a single underscore.
var transformed = toAlphanumericAndUnderscore(this)
// 2. Do not allow empty strings.
if (transformed.isEmpty()) {
return "default_name_${UUID.randomUUID()}" // A fallback name if the input results in an
// empty string
}
// 3. Ensure the identifier does not start with a digit.
// If it starts with a digit, prepend an underscore.
if (transformed[0].isDigit()) {
transformed = "_$transformed"
}
return transformed
}
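A few expected input/output pairs implied by the steps above (sketch only; the empty-string case gets a random suffix, so only the prefix is asserted):

```kotlin
import io.airbyte.integrations.destination.clickhouse.schema.toClickHouseCompatibleName

fun main() {
    check("hello world".toClickHouseCompatibleName() == "hello_world") // step 1: underscores
    check("123abc".toClickHouseCompatibleName() == "_123abc") // step 3: leading digit prefixed
    check("".toClickHouseCompatibleName().startsWith("default_name_")) // step 2: empty fallback
}
```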

View File

@@ -29,8 +29,7 @@ import io.airbyte.cdk.load.schema.model.StreamTableSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.TempTableNameGenerator
import io.airbyte.integrations.destination.clickhouse.client.ClickhouseSqlTypes
import io.airbyte.integrations.destination.clickhouse.client.isValidVersionColumnType
import io.airbyte.integrations.destination.clickhouse.config.toClickHouseCompatibleName
import io.airbyte.integrations.destination.clickhouse.client.isValidVersionColumn
import io.airbyte.integrations.destination.clickhouse.spec.ClickhouseConfiguration
import jakarta.inject.Singleton
@@ -100,7 +99,7 @@ class ClickhouseTableSchemaMapper(
if (cursor != null) {
// Check if the cursor column type is valid for ClickHouse ReplacingMergeTree
val cursorColumnType = tableSchema.columnSchema.finalSchema[cursor]!!.type
if (cursorColumnType.isValidVersionColumnType()) {
if (isValidVersionColumn(cursor, cursorColumnType)) {
// Cursor column is valid, use it as version column
add(cursor) // Make cursor column non-nullable too
}

View File

@@ -0,0 +1,77 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.clickhouse.component
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures.NEGATIVE_HIGH_PRECISION_FLOAT
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures.POSITIVE_HIGH_PRECISION_FLOAT
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures.SMALLEST_NEGATIVE_FLOAT32
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures.SMALLEST_NEGATIVE_FLOAT64
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures.SMALLEST_POSITIVE_FLOAT32
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures.SMALLEST_POSITIVE_FLOAT64
import io.airbyte.cdk.load.component.DataCoercionSuite
import io.airbyte.cdk.load.component.TableOperationsClient
import io.airbyte.cdk.load.component.TestTableOperationsClient
import io.airbyte.cdk.load.component.toArgs
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.dataflow.transform.ValueCoercer
import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Reason
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.MethodSource
@MicronautTest(environments = ["component"], resolveParameters = false)
class ClickhouseDataCoercionTest(
override val coercer: ValueCoercer,
override val opsClient: TableOperationsClient,
override val testClient: TestTableOperationsClient,
override val schemaFactory: TableSchemaFactory,
) : DataCoercionSuite {
@ParameterizedTest
// We use clickhouse's Int64 type for integers
@MethodSource("io.airbyte.cdk.load.component.DataCoercionIntegerFixtures#int64")
override fun `handle integer values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) {
super.`handle integer values`(inputValue, expectedValue, expectedChangeReason)
}
@ParameterizedTest
@MethodSource(
"io.airbyte.integrations.destination.clickhouse.component.ClickhouseDataCoercionTest#numbers"
)
override fun `handle number values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) {
super.`handle number values`(inputValue, expectedValue, expectedChangeReason)
}
companion object {
/**
* destination-clickhouse doesn't set a change reason when truncating high-precision numbers
* (https://github.com/airbytehq/airbyte-internal-issues/issues/15401)
*/
@JvmStatic
fun numbers() =
DataCoercionNumberFixtures.numeric38_9
.map {
when (it.name) {
POSITIVE_HIGH_PRECISION_FLOAT,
NEGATIVE_HIGH_PRECISION_FLOAT,
SMALLEST_POSITIVE_FLOAT32,
SMALLEST_NEGATIVE_FLOAT32,
SMALLEST_POSITIVE_FLOAT64,
SMALLEST_NEGATIVE_FLOAT64 -> it.copy(changeReason = null)
else -> it
}
}
.toArgs()
}
}

View File

@@ -22,7 +22,7 @@ class ClickhouseTableSchemaEvolutionTest(
override val client: TableSchemaEvolutionClient,
override val opsClient: TableOperationsClient,
override val testClient: TestTableOperationsClient,
override val schemaFactory: TableSchemaFactory
override val schemaFactory: TableSchemaFactory,
) : TableSchemaEvolutionSuite {
private val allTypesTableSchema =
TableSchema(

View File

@@ -16,7 +16,7 @@ import io.airbyte.cdk.load.data.TimestampWithTimezoneValue
import io.airbyte.cdk.load.data.TimestampWithoutTimezoneValue
import io.airbyte.cdk.load.test.util.ExpectedRecordMapper
import io.airbyte.cdk.load.test.util.OutputRecord
import io.airbyte.integrations.destination.clickhouse.config.toClickHouseCompatibleName
import io.airbyte.integrations.destination.clickhouse.schema.toClickHouseCompatibleName
import java.math.RoundingMode
import java.time.LocalTime
import java.time.ZoneOffset

View File

@@ -30,8 +30,8 @@ import io.airbyte.cdk.load.write.UnknownTypesBehavior
import io.airbyte.integrations.destination.clickhouse.ClickhouseConfigUpdater
import io.airbyte.integrations.destination.clickhouse.ClickhouseContainerHelper
import io.airbyte.integrations.destination.clickhouse.Utils
import io.airbyte.integrations.destination.clickhouse.config.toClickHouseCompatibleName
import io.airbyte.integrations.destination.clickhouse.fixtures.ClickhouseExpectedRecordMapper
import io.airbyte.integrations.destination.clickhouse.schema.toClickHouseCompatibleName
import io.airbyte.integrations.destination.clickhouse.spec.ClickhouseConfiguration
import io.airbyte.integrations.destination.clickhouse.spec.ClickhouseConfigurationFactory
import io.airbyte.integrations.destination.clickhouse.spec.ClickhouseSpecificationOss

View File

@@ -21,7 +21,6 @@ import io.airbyte.cdk.load.schema.model.StreamTableSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.TempTableNameGenerator
import io.airbyte.integrations.destination.clickhouse.config.ClickhouseFinalTableNameGenerator
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.coVerifyOrder
@@ -39,8 +38,6 @@ class ClickhouseAirbyteClientTest {
// Mocks
private val client: ClickHouseClientRaw = mockk(relaxed = true)
private val clickhouseSqlGenerator: ClickhouseSqlGenerator = mockk(relaxed = true)
private val clickhouseFinalTableNameGenerator: ClickhouseFinalTableNameGenerator =
mockk(relaxed = true)
private val tempTableNameGenerator: TempTableNameGenerator = mockk(relaxed = true)
// Client
@@ -105,7 +102,6 @@ class ClickhouseAirbyteClientTest {
alterTableStatement
coEvery { clickhouseAirbyteClient.execute(alterTableStatement) } returns
mockk(relaxed = true)
every { clickhouseFinalTableNameGenerator.getTableName(any()) } returns mockTableName
mockCHSchemaWithAirbyteColumns()
@@ -172,7 +168,6 @@ class ClickhouseAirbyteClientTest {
coEvery { clickhouseAirbyteClient.execute(any()) } returns mockk(relaxed = true)
every { tempTableNameGenerator.generate(any()) } returns tempTableName
every { clickhouseFinalTableNameGenerator.getTableName(any()) } returns finalTableName
mockCHSchemaWithAirbyteColumns()
@@ -226,8 +221,6 @@ class ClickhouseAirbyteClientTest {
fun `test ensure schema matches fails if no airbyte columns`() = runTest {
val finalTableName = TableName("fin", "al")
every { clickhouseFinalTableNameGenerator.getTableName(any()) } returns finalTableName
val columnMapping = ColumnNameMapping(mapOf())
val stream =
mockk<DestinationStream> {

View File

@@ -2,13 +2,13 @@
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.clickhouse.config
package io.airbyte.integrations.destination.clickhouse.schema
import java.util.UUID
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test
class ClickhouseNameGeneratorTest {
class ClickhouseNamingUtilsTest {
@Test
fun `toClickHouseCompatibleName replaces special characters with underscores`() {
Assertions.assertEquals("hello_world", "hello world".toClickHouseCompatibleName())

View File

@@ -6,7 +6,7 @@ data:
connectorSubtype: database
connectorType: destination
definitionId: 25c5221d-dce2-4163-ade9-739ef790f503
dockerImageTag: 3.0.5-rc.1
dockerImageTag: 3.0.5
dockerRepository: airbyte/destination-postgres
documentationUrl: https://docs.airbyte.com/integrations/destinations/postgres
githubIssueLabel: destination-postgres
@@ -22,7 +22,7 @@ data:
enabled: true
releases:
rolloutConfiguration:
enableProgressiveRollout: true
enableProgressiveRollout: false
breakingChanges:
3.0.0:
message: >

View File

@@ -4,12 +4,16 @@
package io.airbyte.integrations.destination.postgres.client
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
import io.airbyte.cdk.ConfigErrorException
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.component.ColumnChangeset
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.component.TableColumns
import io.airbyte.cdk.load.component.TableOperationsClient
import io.airbyte.cdk.load.component.TableSchema
import io.airbyte.cdk.load.component.TableSchemaEvolutionClient
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAMES
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.ColumnNameMapping
@@ -26,6 +30,11 @@ import javax.sql.DataSource
private val log = KotlinLogging.logger {}
@Singleton
@SuppressFBWarnings(
value = ["SQL_NONCONSTANT_STRING_PASSED_TO_EXECUTE"],
justification =
"There is little chance of SQL injection. There is also little need for statement reuse. The basic statement is more readable than the prepared statement."
)
class PostgresAirbyteClient(
private val dataSource: DataSource,
private val sqlGenerator: PostgresDirectLoadSqlGenerator,
@@ -53,6 +62,29 @@ class PostgresAirbyteClient(
null
}
override suspend fun namespaceExists(namespace: String): Boolean {
return executeQuery(
"""
SELECT EXISTS(
SELECT 1 FROM information_schema.schemata
WHERE schema_name = '$namespace'
)
"""
) { rs -> rs.next() && rs.getBoolean(1) }
}
override suspend fun tableExists(table: TableName): Boolean {
return executeQuery(
"""
SELECT EXISTS(
SELECT 1 FROM information_schema.tables
WHERE table_schema = '${table.namespace}'
AND table_name = '${table.name}'
)
"""
) { rs -> rs.next() && rs.getBoolean(1) }
}
override suspend fun createNamespace(namespace: String) {
try {
execute(sqlGenerator.createNamespace(namespace))
@@ -171,14 +203,26 @@ class PostgresAirbyteClient(
}
override suspend fun discoverSchema(tableName: TableName): TableSchema {
TODO("Not yet implemented")
val columnsInDb = getColumnsFromDbForDiscovery(tableName)
val hasAllAirbyteColumns = columnsInDb.keys.containsAll(COLUMN_NAMES)
if (!hasAllAirbyteColumns) {
val message =
"The target table ($tableName) already exists in the destination, but does not contain Airbyte's internal columns. Airbyte can only sync to Airbyte-controlled tables. To fix this error, you must either delete the target table or add a prefix in the connection configuration in order to sync to a separate table in the destination."
log.error { message }
throw ConfigErrorException(message)
}
// Filter out Airbyte columns
val userColumns = columnsInDb.filterKeys { it !in COLUMN_NAMES }
return TableSchema(userColumns)
}
override fun computeSchema(
stream: DestinationStream,
columnNameMapping: ColumnNameMapping
): TableSchema {
TODO("Not yet implemented")
return TableSchema(stream.tableSchema.columnSchema.finalSchema)
}
override suspend fun applyChangeset(
@@ -188,9 +232,73 @@ class PostgresAirbyteClient(
expectedColumns: TableColumns,
columnChangeset: ColumnChangeset
) {
TODO("Not yet implemented")
if (
columnChangeset.columnsToAdd.isNotEmpty() ||
columnChangeset.columnsToDrop.isNotEmpty() ||
columnChangeset.columnsToChange.isNotEmpty()
) {
log.info { "Summary of the table alterations:" }
log.info { "Added columns: ${columnChangeset.columnsToAdd}" }
log.info { "Deleted columns: ${columnChangeset.columnsToDrop}" }
log.info { "Modified columns: ${columnChangeset.columnsToChange}" }
// Convert from TableColumns format to Column format
val columnsToAdd =
columnChangeset.columnsToAdd
.map { (name, type) -> Column(name, type.type, type.nullable) }
.toSet()
val columnsToRemove =
columnChangeset.columnsToDrop
.map { (name, type) -> Column(name, type.type, type.nullable) }
.toSet()
val columnsToModify =
columnChangeset.columnsToChange
.map { (name, change) ->
Column(name, change.newType.type, change.newType.nullable)
}
.toSet()
val columnsInDb =
(columnChangeset.columnsToRetain +
columnChangeset.columnsToDrop +
columnChangeset.columnsToChange.mapValues { it.value.originalType })
.map { (name, type) -> Column(name, type.type, type.nullable) }
.toSet()
execute(
sqlGenerator.matchSchemas(
tableName = tableName,
columnsToAdd = columnsToAdd,
columnsToRemove = columnsToRemove,
columnsToModify = columnsToModify,
columnsInDb = columnsInDb,
recreatePrimaryKeyIndex = false,
primaryKeyColumnNames = emptyList(),
recreateCursorIndex = false,
cursorColumnName = null,
)
)
}
}
/**
* Gets columns from the database including their types for schema discovery. Unlike
* [getColumnsFromDb], this returns all columns including Airbyte metadata columns.
*/
private fun getColumnsFromDbForDiscovery(tableName: TableName): Map<String, ColumnType> =
executeQuery(sqlGenerator.getTableSchema(tableName)) { rs ->
val columnsInDb: MutableMap<String, ColumnType> = mutableMapOf()
while (rs.next()) {
val columnName = rs.getString(COLUMN_NAME_COLUMN)
val dataType = rs.getString("data_type")
// PostgreSQL's information_schema always returns 'YES' or 'NO' for is_nullable
val isNullable = rs.getString("is_nullable") == "YES"
columnsInDb[columnName] = ColumnType(normalizePostgresType(dataType), isNullable)
}
columnsInDb
}
/**
* Checks if the primary key index matches the current stream configuration. If the primary keys
* have changed (detected by comparing columns in the index), then this will return true,

View File

@@ -531,7 +531,7 @@ class PostgresDirectLoadSqlGenerator(
fun getTableSchema(tableName: TableName): String =
"""
SELECT column_name, data_type
SELECT column_name, data_type, is_nullable
FROM information_schema.columns
WHERE table_schema = '${tableName.namespace}'
AND table_name = '${tableName.name}';

View File

@@ -49,6 +49,7 @@ class PostgresWriter(
override fun createStreamLoader(stream: DestinationStream): StreamLoader {
val initialStatus = initialStatuses[stream]!!
val realTableName = stream.tableSchema.tableNames.finalTableName!!
val tempTableName = tempTableNameGenerator.generate(realTableName)
val columnNameMapping =
ColumnNameMapping(stream.tableSchema.columnSchema.inputToFinalColumnNames)

View File

@@ -0,0 +1,45 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.postgres.component
import io.airbyte.cdk.load.util.Jsons
import io.airbyte.integrations.destination.postgres.PostgresConfigUpdater
import io.airbyte.integrations.destination.postgres.PostgresContainerHelper
import io.airbyte.integrations.destination.postgres.spec.PostgresConfiguration
import io.airbyte.integrations.destination.postgres.spec.PostgresConfigurationFactory
import io.airbyte.integrations.destination.postgres.spec.PostgresSpecificationOss
import io.micronaut.context.annotation.Factory
import io.micronaut.context.annotation.Primary
import io.micronaut.context.annotation.Requires
import jakarta.inject.Singleton
@Requires(env = ["component"])
@Factory
class PostgresComponentTestConfigFactory {
@Singleton
@Primary
fun config(): PostgresConfiguration {
// Start the postgres container
PostgresContainerHelper.start()
// Create a minimal config JSON and update it with container details
val configJson =
"""
{
"host": "replace_me_host",
"port": "replace_me_port",
"database": "replace_me_database",
"schema": "public",
"username": "replace_me_username",
"password": "replace_me_password",
"ssl": false
}
"""
val updatedConfig = PostgresConfigUpdater().update(configJson)
val spec = Jsons.readValue(updatedConfig, PostgresSpecificationOss::class.java)
return PostgresConfigurationFactory().makeWithoutExceptionHandling(spec)
}
}

View File

@@ -0,0 +1,36 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.postgres.component
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.component.TableOperationsFixtures
import io.airbyte.cdk.load.component.TableSchema
object PostgresComponentTestFixtures {
// PostgreSQL uses lowercase column names by default (no transformation needed)
val testMapping = TableOperationsFixtures.TEST_MAPPING
val idAndTestMapping = TableOperationsFixtures.ID_AND_TEST_MAPPING
val idTestWithCdcMapping = TableOperationsFixtures.ID_TEST_WITH_CDC_MAPPING
val allTypesTableSchema =
TableSchema(
mapOf(
"string" to ColumnType("varchar", true),
"boolean" to ColumnType("boolean", true),
"integer" to ColumnType("bigint", true),
"number" to ColumnType("decimal", true),
"date" to ColumnType("date", true),
"timestamp_tz" to ColumnType("timestamp with time zone", true),
"timestamp_ntz" to ColumnType("timestamp", true),
"time_tz" to ColumnType("time with time zone", true),
"time_ntz" to ColumnType("time", true),
"array" to ColumnType("jsonb", true),
"object" to ColumnType("jsonb", true),
"unknown" to ColumnType("jsonb", true),
)
)
val allTypesColumnNameMapping = TableOperationsFixtures.ALL_TYPES_MAPPING
}

View File

@@ -0,0 +1,92 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.postgres.component
import io.airbyte.cdk.load.component.TableOperationsFixtures
import io.airbyte.cdk.load.component.TableOperationsSuite
import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.integrations.destination.postgres.client.PostgresAirbyteClient
import io.airbyte.integrations.destination.postgres.component.PostgresComponentTestFixtures.idTestWithCdcMapping
import io.airbyte.integrations.destination.postgres.component.PostgresComponentTestFixtures.testMapping
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import jakarta.inject.Inject
import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Test
@MicronautTest(environments = ["component"])
class PostgresTableOperationsTest(
override val client: PostgresAirbyteClient,
override val testClient: PostgresTestTableOperationsClient,
) : TableOperationsSuite {
@Inject override lateinit var schemaFactory: TableSchemaFactory
@Test
override fun `connect to database`() {
super.`connect to database`()
}
@Test
override fun `create and drop namespaces`() {
super.`create and drop namespaces`()
}
@Test
override fun `create and drop tables`() {
super.`create and drop tables`()
}
@Test
override fun `insert records`() {
super.`insert records`(
inputRecords = TableOperationsFixtures.SINGLE_TEST_RECORD_INPUT,
expectedRecords = TableOperationsFixtures.SINGLE_TEST_RECORD_EXPECTED,
columnNameMapping = testMapping,
)
}
@Test
override fun `count table rows`() {
super.`count table rows`(columnNameMapping = testMapping)
}
@Test
override fun `overwrite tables`() {
super.`overwrite tables`(
sourceInputRecords = TableOperationsFixtures.OVERWRITE_SOURCE_RECORDS,
targetInputRecords = TableOperationsFixtures.OVERWRITE_TARGET_RECORDS,
expectedRecords = TableOperationsFixtures.OVERWRITE_EXPECTED_RECORDS,
columnNameMapping = testMapping,
)
}
@Test
override fun `copy tables`() {
super.`copy tables`(
sourceInputRecords = TableOperationsFixtures.OVERWRITE_SOURCE_RECORDS,
targetInputRecords = TableOperationsFixtures.OVERWRITE_TARGET_RECORDS,
expectedRecords = TableOperationsFixtures.COPY_EXPECTED_RECORDS,
columnNameMapping = testMapping,
)
}
@Test
override fun `get generation id`() {
super.`get generation id`(columnNameMapping = testMapping)
}
// TODO: Re-enable when CDK TableOperationsSuite is fixed to use ID_AND_TEST_SCHEMA for target
// table instead of TEST_INTEGER_SCHEMA (the Dedupe mode requires the id column as primary key)
@Disabled("CDK TableOperationsSuite bug: target table schema missing 'id' column for Dedupe")
@Test
override fun `upsert tables`() {
super.`upsert tables`(
sourceInputRecords = TableOperationsFixtures.UPSERT_SOURCE_RECORDS,
targetInputRecords = TableOperationsFixtures.UPSERT_TARGET_RECORDS,
expectedRecords = TableOperationsFixtures.UPSERT_EXPECTED_RECORDS,
columnNameMapping = idTestWithCdcMapping,
)
}
}

View File

@@ -0,0 +1,111 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.postgres.component
import io.airbyte.cdk.load.command.ImportType
import io.airbyte.cdk.load.component.TableSchemaEvolutionFixtures
import io.airbyte.cdk.load.component.TableSchemaEvolutionSuite
import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.integrations.destination.postgres.client.PostgresAirbyteClient
import io.airbyte.integrations.destination.postgres.component.PostgresComponentTestFixtures.allTypesColumnNameMapping
import io.airbyte.integrations.destination.postgres.component.PostgresComponentTestFixtures.allTypesTableSchema
import io.airbyte.integrations.destination.postgres.component.PostgresComponentTestFixtures.idAndTestMapping
import io.airbyte.integrations.destination.postgres.component.PostgresComponentTestFixtures.testMapping
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import org.junit.jupiter.api.Test
@MicronautTest(environments = ["component"], resolveParameters = false)
class PostgresTableSchemaEvolutionTest(
override val client: PostgresAirbyteClient,
override val opsClient: PostgresAirbyteClient,
override val testClient: PostgresTestTableOperationsClient,
override val schemaFactory: TableSchemaFactory,
) : TableSchemaEvolutionSuite {
@Test
fun `discover recognizes all data types`() {
super.`discover recognizes all data types`(allTypesTableSchema, allTypesColumnNameMapping)
}
@Test
fun `computeSchema handles all data types`() {
super.`computeSchema handles all data types`(allTypesTableSchema, allTypesColumnNameMapping)
}
@Test
override fun `noop diff`() {
super.`noop diff`(testMapping)
}
@Test
override fun `changeset is correct when adding a column`() {
super.`changeset is correct when adding a column`(testMapping, idAndTestMapping)
}
@Test
override fun `changeset is correct when dropping a column`() {
super.`changeset is correct when dropping a column`(idAndTestMapping, testMapping)
}
@Test
override fun `changeset is correct when changing a column's type`() {
super.`changeset is correct when changing a column's type`(testMapping)
}
@Test
override fun `apply changeset - handle sync mode append`() {
super.`apply changeset - handle sync mode append`()
}
@Test
override fun `apply changeset - handle changing sync mode from append to dedup`() {
super.`apply changeset - handle changing sync mode from append to dedup`()
}
@Test
override fun `apply changeset - handle changing sync mode from dedup to append`() {
super.`apply changeset - handle changing sync mode from dedup to append`()
}
@Test
override fun `apply changeset - handle sync mode dedup`() {
super.`apply changeset - handle sync mode dedup`()
}
override fun `apply changeset`(
initialStreamImportType: ImportType,
modifiedStreamImportType: ImportType,
) {
super.`apply changeset`(
initialColumnNameMapping =
TableSchemaEvolutionFixtures.APPLY_CHANGESET_INITIAL_COLUMN_MAPPING,
modifiedColumnNameMapping =
TableSchemaEvolutionFixtures.APPLY_CHANGESET_MODIFIED_COLUMN_MAPPING,
TableSchemaEvolutionFixtures.APPLY_CHANGESET_EXPECTED_EXTRACTED_AT,
initialStreamImportType,
modifiedStreamImportType,
)
}
@Test
override fun `change from string type to unknown type`() {
super.`change from string type to unknown type`(
idAndTestMapping,
idAndTestMapping,
TableSchemaEvolutionFixtures.STRING_TO_UNKNOWN_TYPE_INPUT_RECORDS,
TableSchemaEvolutionFixtures.STRING_TO_UNKNOWN_TYPE_EXPECTED_RECORDS,
)
}
@Test
override fun `change from unknown type to string type`() {
super.`change from unknown type to string type`(
idAndTestMapping,
idAndTestMapping,
TableSchemaEvolutionFixtures.UNKNOWN_TO_STRING_TYPE_INPUT_RECORDS,
TableSchemaEvolutionFixtures.UNKNOWN_TO_STRING_TYPE_EXPECTED_RECORDS,
)
}
}

View File

@@ -0,0 +1,257 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.postgres.component
import io.airbyte.cdk.load.component.TestTableOperationsClient
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.util.Jsons
import io.airbyte.integrations.destination.postgres.client.PostgresAirbyteClient
import io.micronaut.context.annotation.Requires
import jakarta.inject.Singleton
import java.time.OffsetDateTime
import java.time.ZoneOffset
import java.time.format.DateTimeFormatter
import javax.sql.DataSource
@Requires(env = ["component"])
@Singleton
class PostgresTestTableOperationsClient(
private val dataSource: DataSource,
private val client: PostgresAirbyteClient,
) : TestTableOperationsClient {
override suspend fun ping() {
dataSource.connection.use { connection ->
connection.createStatement().use { statement -> statement.executeQuery("SELECT 1") }
}
}
override suspend fun dropNamespace(namespace: String) {
dataSource.connection.use { connection ->
connection.createStatement().use { statement ->
statement.execute("DROP SCHEMA IF EXISTS \"$namespace\" CASCADE")
}
}
}
override suspend fun insertRecords(table: TableName, records: List<Map<String, AirbyteValue>>) {
if (records.isEmpty()) return
// Get column types from database to handle jsonb columns properly
val columnTypes = getColumnTypes(table)
// Get all unique columns from ALL records to handle sparse data (e.g., CDC deletion column)
val columns = records.flatMap { it.keys }.distinct().toList()
val columnNames = columns.joinToString(", ") { "\"$it\"" }
val placeholders = columns.indices.joinToString(", ") { "?" }
val sql =
"""
INSERT INTO "${table.namespace}"."${table.name}" ($columnNames)
VALUES ($placeholders)
"""
dataSource.connection.use { connection ->
connection.prepareStatement(sql).use { statement ->
for (record in records) {
columns.forEachIndexed { index, column ->
val value = record[column]
val columnType = columnTypes[column]
setParameterValue(statement, index + 1, value, columnType)
}
statement.addBatch()
}
statement.executeBatch()
}
}
}
private fun getColumnTypes(table: TableName): Map<String, String> {
val columnTypes = mutableMapOf<String, String>()
dataSource.connection.use { connection ->
connection.createStatement().use { statement ->
statement
.executeQuery(
"""
SELECT column_name, data_type
FROM information_schema.columns
WHERE table_schema = '${table.namespace}'
AND table_name = '${table.name}'
"""
)
.use { resultSet ->
while (resultSet.next()) {
columnTypes[resultSet.getString("column_name")] =
resultSet.getString("data_type")
}
}
}
}
return columnTypes
}
private fun setParameterValue(
statement: java.sql.PreparedStatement,
index: Int,
value: AirbyteValue?,
columnType: String?
) {
// If column is jsonb, serialize any value as JSON
if (columnType == "jsonb") {
if (value == null || value is io.airbyte.cdk.load.data.NullValue) {
statement.setNull(index, java.sql.Types.OTHER)
} else {
val pgObject = org.postgresql.util.PGobject()
pgObject.type = "jsonb"
pgObject.value = serializeToJson(value)
statement.setObject(index, pgObject)
}
return
}
when (value) {
null,
is io.airbyte.cdk.load.data.NullValue -> statement.setNull(index, java.sql.Types.NULL)
is io.airbyte.cdk.load.data.StringValue -> statement.setString(index, value.value)
is io.airbyte.cdk.load.data.IntegerValue ->
statement.setLong(index, value.value.toLong())
is io.airbyte.cdk.load.data.NumberValue -> statement.setBigDecimal(index, value.value)
is io.airbyte.cdk.load.data.BooleanValue -> statement.setBoolean(index, value.value)
is io.airbyte.cdk.load.data.TimestampWithTimezoneValue -> {
val offsetDateTime = OffsetDateTime.parse(value.value.toString())
statement.setObject(index, offsetDateTime)
}
is io.airbyte.cdk.load.data.TimestampWithoutTimezoneValue -> {
val localDateTime = java.time.LocalDateTime.parse(value.value.toString())
statement.setObject(index, localDateTime)
}
is io.airbyte.cdk.load.data.DateValue -> {
val localDate = java.time.LocalDate.parse(value.value.toString())
statement.setObject(index, localDate)
}
is io.airbyte.cdk.load.data.TimeWithTimezoneValue -> {
statement.setString(index, value.value.toString())
}
is io.airbyte.cdk.load.data.TimeWithoutTimezoneValue -> {
val localTime = java.time.LocalTime.parse(value.value.toString())
statement.setObject(index, localTime)
}
is io.airbyte.cdk.load.data.ObjectValue -> {
val pgObject = org.postgresql.util.PGobject()
pgObject.type = "jsonb"
pgObject.value = Jsons.writeValueAsString(value.values)
statement.setObject(index, pgObject)
}
is io.airbyte.cdk.load.data.ArrayValue -> {
val pgObject = org.postgresql.util.PGobject()
pgObject.type = "jsonb"
pgObject.value = Jsons.writeValueAsString(value.values)
statement.setObject(index, pgObject)
}
else -> {
// For unknown types, try to serialize as string
statement.setString(index, value.toString())
}
}
}
private fun serializeToJson(value: AirbyteValue): String {
return when (value) {
is io.airbyte.cdk.load.data.StringValue -> Jsons.writeValueAsString(value.value)
is io.airbyte.cdk.load.data.IntegerValue -> value.value.toString()
is io.airbyte.cdk.load.data.NumberValue -> value.value.toString()
is io.airbyte.cdk.load.data.BooleanValue -> value.value.toString()
is io.airbyte.cdk.load.data.ObjectValue -> Jsons.writeValueAsString(value.values)
is io.airbyte.cdk.load.data.ArrayValue -> Jsons.writeValueAsString(value.values)
is io.airbyte.cdk.load.data.NullValue -> "null"
else -> Jsons.writeValueAsString(value.toString())
}
}
override suspend fun readTable(table: TableName): List<Map<String, Any>> {
dataSource.connection.use { connection ->
connection.createStatement().use { statement ->
statement
.executeQuery("""SELECT * FROM "${table.namespace}"."${table.name}"""")
.use { resultSet ->
val metaData = resultSet.metaData
val columnCount = metaData.columnCount
val result = mutableListOf<Map<String, Any>>()
while (resultSet.next()) {
val row = mutableMapOf<String, Any>()
for (i in 1..columnCount) {
val columnName = metaData.getColumnName(i)
val columnType = metaData.getColumnTypeName(i)
when (columnType.lowercase()) {
"timestamptz" -> {
val value =
resultSet.getObject(i, OffsetDateTime::class.java)
if (value != null) {
val formattedTimestamp =
DateTimeFormatter.ISO_OFFSET_DATE_TIME.format(
value.withOffsetSameInstant(ZoneOffset.UTC)
)
row[columnName] = formattedTimestamp
}
}
"timestamp" -> {
val value = resultSet.getTimestamp(i)
if (value != null) {
val localDateTime = value.toLocalDateTime()
row[columnName] =
DateTimeFormatter.ISO_LOCAL_DATE_TIME.format(
localDateTime
)
}
}
"jsonb",
"json" -> {
val stringValue: String? = resultSet.getString(i)
if (stringValue != null) {
val parsedValue =
Jsons.readValue(stringValue, Any::class.java)
val actualValue =
when (parsedValue) {
is Int -> parsedValue.toLong()
else -> parsedValue
}
row[columnName] = actualValue
}
}
else -> {
val value = resultSet.getObject(i)
if (value != null) {
// For varchar columns that may contain JSON (from
// schema evolution),
// normalize the JSON to compact format for comparison
if (
value is String &&
(value.startsWith("{") || value.startsWith("["))
) {
try {
val parsed =
Jsons.readValue(value, Any::class.java)
row[columnName] =
Jsons.writeValueAsString(parsed)
} catch (_: Exception) {
row[columnName] = value
}
} else {
row[columnName] = value
}
}
}
}
}
result.add(row)
}
return result
}
}
}
}
}

View File

@@ -267,7 +267,7 @@ class PostgresRawDataDumper(
.lowercase()
.toPostgresCompatibleName()
val fullyQualifiedTableName = "$rawNamespace.$rawName"
val fullyQualifiedTableName = "\"$rawNamespace\".\"$rawName\""
// Check if table exists first
val tableExistsQuery =
@@ -302,6 +302,26 @@ class PostgresRawDataDumper(
false
}
// Build the column name mapping from original names to transformed names
// We use the stream schema to get the original field names, then transform them
// using the postgres name transformation logic
val finalToInputColumnNames = mutableMapOf<String, String>()
if (stream.schema is ObjectType) {
val objectSchema = stream.schema as ObjectType
for (fieldName in objectSchema.properties.keys) {
val transformedName = fieldName.toPostgresCompatibleName()
// Map transformed name back to original name
finalToInputColumnNames[transformedName] = fieldName
}
}
// Also check if inputToFinalColumnNames mapping is available
val inputToFinalColumnNames =
stream.tableSchema.columnSchema.inputToFinalColumnNames
// Add entries from the existing mapping (in case it was populated)
for ((input, final) in inputToFinalColumnNames) {
finalToInputColumnNames[final] = input
}
while (resultSet.next()) {
val rawData =
if (hasDataColumn) {
@@ -313,8 +333,22 @@ class PostgresRawDataDumper(
else -> dataObject?.toString() ?: "{}"
}
// Parse JSON to AirbyteValue, then coerce it to match the schema
dataJson?.deserializeToNode()?.toAirbyteValue() ?: NullValue
// Parse JSON to AirbyteValue, then map column names back to originals
val parsedValue =
dataJson?.deserializeToNode()?.toAirbyteValue() ?: NullValue
// If the parsed value is an ObjectValue, map the column names back
if (parsedValue is ObjectValue) {
val mappedProperties = linkedMapOf<String, AirbyteValue>()
for ((key, value) in parsedValue.values) {
// Map final column name back to input column name if mapping
// exists
val originalKey = finalToInputColumnNames[key] ?: key
mappedProperties[originalKey] = value
}
ObjectValue(mappedProperties)
} else {
parsedValue
}
} else {
// Typed table mode: read from individual columns and reconstruct the
// object
@@ -333,10 +367,19 @@ class PostgresRawDataDumper(
for ((fieldName, fieldType) in objectSchema.properties) {
try {
// Map input field name to the transformed final column name
// First check the inputToFinalColumnNames mapping, then fall back to
// applying the postgres transformation directly
val transformedColumnName =
inputToFinalColumnNames[fieldName]
?: fieldName.toPostgresCompatibleName()
// Try to find the actual column name (case-insensitive
// lookup)
val actualColumnName =
columnMap[fieldName.lowercase()] ?: fieldName
columnMap[transformedColumnName.lowercase()]
?: transformedColumnName
val columnValue = resultSet.getObject(actualColumnName)
properties[fieldName] =
when (columnValue) {

View File

@@ -1,3 +1,3 @@
testExecutionConcurrency=-1
cdkVersion=0.1.82
cdkVersion=0.1.91
JunitMethodExecutionTimeout=10m

View File

@@ -6,7 +6,7 @@ data:
connectorSubtype: database
connectorType: destination
definitionId: 424892c4-daac-4491-b35d-c6688ba547ba
dockerImageTag: 4.0.31
dockerImageTag: 4.0.32-rc.1
dockerRepository: airbyte/destination-snowflake
documentationUrl: https://docs.airbyte.com/integrations/destinations/snowflake
githubIssueLabel: destination-snowflake
@@ -31,6 +31,8 @@ data:
enabled: true
releaseStage: generally_available
releases:
rolloutConfiguration:
enableProgressiveRollout: true
breakingChanges:
2.0.0:
message: Remove GCS/S3 loading method support.

View File

@@ -11,15 +11,18 @@ import io.airbyte.cdk.load.check.CheckOperationV2
import io.airbyte.cdk.load.check.DestinationCheckerV2
import io.airbyte.cdk.load.config.DataChannelMedium
import io.airbyte.cdk.load.dataflow.config.AggregatePublishingConfig
import io.airbyte.cdk.load.orchestration.db.DefaultTempTableNameGenerator
import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator
import io.airbyte.cdk.load.table.TempTableNameGenerator
import io.airbyte.cdk.output.OutputConsumer
import io.airbyte.integrations.destination.snowflake.cdk.SnowflakeMigratingConfigurationSpecificationSupplier
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.KeyPairAuthConfiguration
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfigurationFactory
import io.airbyte.integrations.destination.snowflake.spec.UsernamePasswordAuthConfiguration
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRawRecordFormatter
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter
import io.micronaut.context.annotation.Factory
import io.micronaut.context.annotation.Primary
import io.micronaut.context.annotation.Requires
@@ -204,6 +207,17 @@ class SnowflakeBeanFactory {
outputConsumer: OutputConsumer,
) = CheckOperationV2(destinationChecker, outputConsumer)
@Singleton
fun snowflakeRecordFormatter(
snowflakeConfiguration: SnowflakeConfiguration
): SnowflakeRecordFormatter {
return if (snowflakeConfiguration.legacyRawTablesOnly) {
SnowflakeRawRecordFormatter()
} else {
SnowflakeSchemaRecordFormatter()
}
}
@Singleton
fun aggregatePublishingConfig(dataChannelMedium: DataChannelMedium): AggregatePublishingConfig {
// NOT speed mode

View File

@@ -13,13 +13,17 @@ import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.cdk.load.schema.model.StreamTableSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.schema.model.TableNames
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.TableName
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter
import jakarta.inject.Singleton
import java.time.OffsetDateTime
import java.util.UUID
@@ -31,7 +35,7 @@ internal const val CHECK_COLUMN_NAME = "test_key"
class SnowflakeChecker(
private val snowflakeAirbyteClient: SnowflakeAirbyteClient,
private val snowflakeConfiguration: SnowflakeConfiguration,
private val snowflakeColumnUtils: SnowflakeColumnUtils,
private val columnManager: SnowflakeColumnManager,
) : DestinationCheckerV2 {
override fun check() {
@@ -46,11 +50,40 @@ class SnowflakeChecker(
Meta.AirbyteMetaFields.GENERATION_ID.fieldName to AirbyteValue.from(0),
CHECK_COLUMN_NAME.toSnowflakeCompatibleName() to AirbyteValue.from("test-value")
)
val outputSchema = snowflakeConfiguration.schema.toSnowflakeCompatibleName()
val outputSchema =
if (snowflakeConfiguration.legacyRawTablesOnly) {
snowflakeConfiguration.schema
} else {
snowflakeConfiguration.schema.toSnowflakeCompatibleName()
}
val tableName =
"_airbyte_connection_test_${
UUID.randomUUID().toString().replace("-".toRegex(), "")}".toSnowflakeCompatibleName()
val qualifiedTableName = TableName(namespace = outputSchema, name = tableName)
val tableSchema =
StreamTableSchema(
tableNames =
TableNames(
finalTableName = qualifiedTableName,
tempTableName = qualifiedTableName
),
columnSchema =
ColumnSchema(
inputToFinalColumnNames =
mapOf(
CHECK_COLUMN_NAME to CHECK_COLUMN_NAME.toSnowflakeCompatibleName()
),
finalSchema =
mapOf(
CHECK_COLUMN_NAME.toSnowflakeCompatibleName() to
io.airbyte.cdk.load.component.ColumnType("VARCHAR", false)
),
inputSchema =
mapOf(CHECK_COLUMN_NAME to FieldType(StringType, nullable = false))
),
importType = Append
)
val destinationStream =
DestinationStream(
unmappedNamespace = outputSchema,
@@ -63,7 +96,8 @@ class SnowflakeChecker(
generationId = 0L,
minimumGenerationId = 0L,
syncId = 0L,
namespaceMapper = NamespaceMapper()
namespaceMapper = NamespaceMapper(),
tableSchema = tableSchema
)
runBlocking {
try {
@@ -75,14 +109,14 @@ class SnowflakeChecker(
replace = true,
)
val columns = snowflakeAirbyteClient.describeTable(qualifiedTableName)
val snowflakeInsertBuffer =
SnowflakeInsertBuffer(
tableName = qualifiedTableName,
columns = columns,
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
snowflakeColumnUtils = snowflakeColumnUtils,
columnSchema = tableSchema.columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = SnowflakeSchemaRecordFormatter(),
)
snowflakeInsertBuffer.accumulate(data)

View File

@@ -13,18 +13,16 @@ import io.airbyte.cdk.load.component.TableColumns
import io.airbyte.cdk.load.component.TableOperationsClient
import io.airbyte.cdk.load.component.TableSchema
import io.airbyte.cdk.load.component.TableSchemaEvolutionClient
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.util.deserializeToNode
import io.airbyte.integrations.destination.snowflake.db.ColumnDefinition
import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.COUNT_TOTAL_ALIAS
import io.airbyte.integrations.destination.snowflake.sql.NOT_NULL
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
import io.airbyte.integrations.destination.snowflake.sql.andLog
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton
import java.sql.ResultSet
@@ -41,13 +39,10 @@ private val log = KotlinLogging.logger {}
class SnowflakeAirbyteClient(
private val dataSource: DataSource,
private val sqlGenerator: SnowflakeDirectLoadSqlGenerator,
private val snowflakeColumnUtils: SnowflakeColumnUtils,
private val snowflakeConfiguration: SnowflakeConfiguration,
private val columnManager: SnowflakeColumnManager,
) : TableOperationsClient, TableSchemaEvolutionClient {
private val airbyteColumnNames =
snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet()
override suspend fun countTable(tableName: TableName): Long? =
try {
dataSource.connection.use { connection ->
@@ -126,7 +121,7 @@ class SnowflakeAirbyteClient(
columnNameMapping: ColumnNameMapping,
replace: Boolean
) {
execute(sqlGenerator.createTable(stream, tableName, columnNameMapping, replace))
execute(sqlGenerator.createTable(tableName, stream.tableSchema, replace))
execute(sqlGenerator.createSnowflakeStage(tableName))
}
@@ -163,7 +158,15 @@ class SnowflakeAirbyteClient(
sourceTableName: TableName,
targetTableName: TableName
) {
execute(sqlGenerator.copyTable(columnNameMapping, sourceTableName, targetTableName))
// Get all column names from the mapping (both meta columns and user columns)
val columnNames = buildSet {
// Add Airbyte meta columns (using uppercase constants)
addAll(columnManager.getMetaColumnNames())
// Add user columns from mapping
addAll(columnNameMapping.values)
}
execute(sqlGenerator.copyTable(columnNames, sourceTableName, targetTableName))
}
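// Illustrative note, not part of this change: assuming the standard Airbyte meta columns and a
// single hypothetical user column "NAME", the statement produced by sqlGenerator.copyTable takes
// the shape:
//   INSERT INTO "DB"."SCHEMA"."TARGET" ("_AIRBYTE_RAW_ID", "_AIRBYTE_EXTRACTED_AT", "_AIRBYTE_META", "_AIRBYTE_GENERATION_ID", "NAME")
//   SELECT "_AIRBYTE_RAW_ID", "_AIRBYTE_EXTRACTED_AT", "_AIRBYTE_META", "_AIRBYTE_GENERATION_ID", "NAME"
//   FROM "DB"."SCHEMA"."SOURCE"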
override suspend fun upsertTable(
@@ -172,9 +175,7 @@ class SnowflakeAirbyteClient(
sourceTableName: TableName,
targetTableName: TableName
) {
execute(
sqlGenerator.upsertTable(stream, columnNameMapping, sourceTableName, targetTableName)
)
execute(sqlGenerator.upsertTable(stream.tableSchema, sourceTableName, targetTableName))
}
override suspend fun dropTable(tableName: TableName) {
@@ -206,7 +207,7 @@ class SnowflakeAirbyteClient(
stream: DestinationStream,
columnNameMapping: ColumnNameMapping
): TableSchema {
return TableSchema(getColumnsFromStream(stream, columnNameMapping))
return TableSchema(stream.tableSchema.columnSchema.finalSchema)
}
override suspend fun applyChangeset(
@@ -253,7 +254,7 @@ class SnowflakeAirbyteClient(
val columnName = escapeJsonIdentifier(rs.getString("name"))
// Filter out airbyte columns
if (airbyteColumnNames.contains(columnName)) {
if (columnManager.getMetaColumnNames().contains(columnName)) {
continue
}
val dataType = rs.getString("type").takeWhile { char -> char != '(' }
@@ -271,49 +272,6 @@ class SnowflakeAirbyteClient(
}
}
internal fun getColumnsFromStream(
stream: DestinationStream,
columnNameMapping: ColumnNameMapping
): Map<String, ColumnType> =
snowflakeColumnUtils
.columnsAndTypes(stream.schema.asColumns(), columnNameMapping)
.filter { column -> column.columnName !in airbyteColumnNames }
.associate { column ->
// columnsAndTypes returns types as either `FOO` or `FOO NOT NULL`.
// so check for that suffix.
val nullable = !column.columnType.endsWith(NOT_NULL)
val type =
column.columnType
.takeWhile { char ->
// This is to remove any precision parts of the dialect type
char != '('
}
.removeSuffix(NOT_NULL)
.trim()
column.columnName to ColumnType(type, nullable)
}
internal fun generateSchemaChanges(
columnsInDb: Set<ColumnDefinition>,
columnsInStream: Set<ColumnDefinition>
): Triple<Set<ColumnDefinition>, Set<ColumnDefinition>, Set<ColumnDefinition>> {
val addedColumns =
columnsInStream.filter { it.name !in columnsInDb.map { col -> col.name } }.toSet()
val deletedColumns =
columnsInDb.filter { it.name !in columnsInStream.map { col -> col.name } }.toSet()
val commonColumns =
columnsInStream.filter { it.name in columnsInDb.map { col -> col.name } }.toSet()
val modifiedColumns =
commonColumns
.filter {
val dbType = columnsInDb.find { column -> it.name == column.name }?.type
it.type != dbType
}
.toSet()
return Triple(addedColumns, deletedColumns, modifiedColumns)
}
override suspend fun getGenerationId(tableName: TableName): Long =
try {
dataSource.connection.use { connection ->
@@ -326,7 +284,7 @@ class SnowflakeAirbyteClient(
* format. In order to make sure these strings will match any column names
* that we have formatted in-memory, re-apply the escaping.
*/
resultSet.getLong(snowflakeColumnUtils.getGenerationIdColumnName())
resultSet.getLong(columnManager.getGenerationIdColumnName())
} else {
log.warn {
"No generation ID found for table ${tableName.toPrettyString()}, returning 0"
@@ -351,8 +309,8 @@ class SnowflakeAirbyteClient(
execute(sqlGenerator.putInStage(tableName, tempFilePath))
}
fun copyFromStage(tableName: TableName, filename: String) {
execute(sqlGenerator.copyFromStage(tableName, filename))
fun copyFromStage(tableName: TableName, filename: String, columnNames: List<String>) {
execute(sqlGenerator.copyFromStage(tableName, filename, columnNames))
}
fun describeTable(tableName: TableName): LinkedHashMap<String, String> =

View File

@@ -4,47 +4,41 @@
package io.airbyte.integrations.destination.snowflake.dataflow
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.dataflow.aggregate.Aggregate
import io.airbyte.cdk.load.dataflow.aggregate.AggregateFactory
import io.airbyte.cdk.load.dataflow.aggregate.StoreKey
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.write.StreamStateStore
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
import io.micronaut.cache.annotation.CacheConfig
import io.micronaut.cache.annotation.Cacheable
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
import jakarta.inject.Singleton
@Singleton
@CacheConfig("table-columns")
// class has to be open to make the cache stuff work
open class SnowflakeAggregateFactory(
class SnowflakeAggregateFactory(
private val snowflakeClient: SnowflakeAirbyteClient,
private val streamStateStore: StreamStateStore<DirectLoadTableExecutionConfig>,
private val snowflakeConfiguration: SnowflakeConfiguration,
private val snowflakeColumnUtils: SnowflakeColumnUtils,
private val catalog: DestinationCatalog,
private val columnManager: SnowflakeColumnManager,
private val snowflakeRecordFormatter: SnowflakeRecordFormatter,
) : AggregateFactory {
override fun create(key: StoreKey): Aggregate {
val stream = catalog.getStream(key)
val tableName = streamStateStore.get(key)!!.tableName
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = getTableColumns(tableName),
snowflakeClient = snowflakeClient,
snowflakeConfiguration = snowflakeConfiguration,
snowflakeColumnUtils = snowflakeColumnUtils,
columnSchema = stream.tableSchema.columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
)
return SnowflakeAggregate(buffer = buffer)
}
// We assume that a table isn't getting altered _during_ a sync.
// This allows us to only SHOW COLUMNS once per table per sync,
// rather than refetching it on every aggregate.
@Cacheable
// function has to be open to make caching work
internal open fun getTableColumns(tableName: TableName) =
snowflakeClient.describeTable(tableName)
}

View File

@@ -1,13 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.db
/**
* Jdbc destination column definition representation
*
* @param name
* @param type
*/
data class ColumnDefinition(val name: String, val type: String)

View File

@@ -4,17 +4,17 @@
package io.airbyte.integrations.destination.snowflake.db
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.component.TableOperationsClient
import io.airbyte.cdk.load.orchestration.db.BaseDirectLoadInitialStatusGatherer
import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
import io.airbyte.cdk.load.table.BaseDirectLoadInitialStatusGatherer
import jakarta.inject.Singleton
@Singleton
class SnowflakeDirectLoadDatabaseInitialStatusGatherer(
tableOperationsClient: TableOperationsClient,
tempTableNameGenerator: TempTableNameGenerator,
catalog: DestinationCatalog,
) :
BaseDirectLoadInitialStatusGatherer(
tableOperationsClient,
tempTableNameGenerator,
catalog,
)

View File

@@ -1,95 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.db
import io.airbyte.cdk.ConfigErrorException
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.orchestration.db.ColumnNameGenerator
import io.airbyte.cdk.load.orchestration.db.FinalTableNameGenerator
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TypingDedupingUtil
import io.airbyte.cdk.load.table.TableName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.QUOTE
import jakarta.inject.Singleton
@Singleton
class SnowflakeFinalTableNameGenerator(private val config: SnowflakeConfiguration) :
FinalTableNameGenerator {
override fun getTableName(streamDescriptor: DestinationStream.Descriptor): TableName {
val namespace = streamDescriptor.namespace ?: config.schema
return if (!config.legacyRawTablesOnly) {
TableName(
namespace = namespace.toSnowflakeCompatibleName(),
name = streamDescriptor.name.toSnowflakeCompatibleName(),
)
} else {
TableName(
namespace = config.internalTableSchema,
name =
TypingDedupingUtil.concatenateRawTableName(
namespace = escapeJsonIdentifier(namespace),
name = escapeJsonIdentifier(streamDescriptor.name),
),
)
}
}
}
@Singleton
class SnowflakeColumnNameGenerator(private val config: SnowflakeConfiguration) :
ColumnNameGenerator {
override fun getColumnName(column: String): ColumnNameGenerator.ColumnName {
return if (!config.legacyRawTablesOnly) {
ColumnNameGenerator.ColumnName(
column.toSnowflakeCompatibleName(),
column.toSnowflakeCompatibleName(),
)
} else {
ColumnNameGenerator.ColumnName(
column,
column,
)
}
}
}
/**
* Escapes double-quotes in a JSON identifier by doubling them. This is legacy -- I don't know
* why this would be necessary, but there is no harm in keeping it, so I am keeping it.
*
* @return The escaped identifier.
*/
fun escapeJsonIdentifier(identifier: String): String {
// Note that we don't need to escape backslashes here!
// The only special character in an identifier is the double-quote, which needs to be
// doubled.
return identifier.replace(QUOTE, "$QUOTE$QUOTE")
}
/**
* Transforms a string to be compatible with Snowflake table and column names.
*
* @return The transformed string suitable for Snowflake identifiers.
*/
fun String.toSnowflakeCompatibleName(): String {
var identifier = this
// Handle empty strings
if (identifier.isEmpty()) {
throw ConfigErrorException("Empty string is invalid identifier")
}
// Snowflake scripting language does something weird when the `${` bigram shows up in the
// script so replace these with something else.
// For completeness, if we trigger this, also replace closing curly braces with underscores.
if (identifier.contains("\${")) {
identifier = identifier.replace("$", "_").replace("{", "_").replace("}", "_")
}
// Escape double quotes
identifier = escapeJsonIdentifier(identifier)
return identifier.uppercase()
}

View File

@@ -0,0 +1,116 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.schema
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_EXTRACTED_AT
import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_GENERATION_ID
import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_META
import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_RAW_ID
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDataType
import jakarta.inject.Singleton
/**
* Manages column names and ordering for Snowflake tables based on whether legacy raw tables mode is
* enabled.
*
* TODO: We should add meta column munging and raw table support to the CDK, so this extra layer of
* management shouldn't be necessary.
*/
@Singleton
class SnowflakeColumnManager(
private val config: SnowflakeConfiguration,
) {
/**
* Get the list of column names for a table in the order they should appear in the CSV file and
* COPY INTO statement.
*
* Warning: MUST match the order defined in SnowflakeRecordFormatter
*
* @param columnSchema The schema containing column information (in raw mode its final schema holds only the raw data column)
* @return List of column names in the correct order
*/
fun getTableColumnNames(columnSchema: ColumnSchema): List<String> {
return buildList {
addAll(getMetaColumnNames())
addAll(columnSchema.finalSchema.keys)
}
}
/**
* Get the list of Airbyte meta column names. In schema mode, these are uppercase. In raw mode,
* they are lowercase and include the loaded_at column.
*
* @return Set of meta column names
*/
fun getMetaColumnNames(): Set<String> =
if (config.legacyRawTablesOnly) {
Constants.rawModeMetaColNames
} else {
Constants.schemaModeMetaColNames
}
/**
* Get the Airbyte meta columns as a map of column name to ColumnType. This provides both the
* column names and their types for table creation.
*
* @return Map of meta column names to their types
*/
fun getMetaColumns(): LinkedHashMap<String, ColumnType> {
return if (config.legacyRawTablesOnly) {
Constants.rawModeMetaColumns
} else {
Constants.schemaModeMetaColumns
}
}
fun getGenerationIdColumnName(): String {
return if (config.legacyRawTablesOnly) {
Meta.COLUMN_NAME_AB_GENERATION_ID
} else {
SNOWFLAKE_AB_GENERATION_ID
}
}
object Constants {
val rawModeMetaColumns =
linkedMapOf(
Meta.COLUMN_NAME_AB_RAW_ID to ColumnType(SnowflakeDataType.VARCHAR.typeName, false),
Meta.COLUMN_NAME_AB_EXTRACTED_AT to
ColumnType(
SnowflakeDataType.TIMESTAMP_TZ.typeName,
false,
),
Meta.COLUMN_NAME_AB_META to ColumnType(SnowflakeDataType.VARIANT.typeName, false),
Meta.COLUMN_NAME_AB_GENERATION_ID to
ColumnType(
SnowflakeDataType.NUMBER.typeName,
true,
),
Meta.COLUMN_NAME_AB_LOADED_AT to
ColumnType(
SnowflakeDataType.TIMESTAMP_TZ.typeName,
true,
),
)
val schemaModeMetaColumns =
linkedMapOf(
SNOWFLAKE_AB_RAW_ID to ColumnType(SnowflakeDataType.VARCHAR.typeName, false),
SNOWFLAKE_AB_EXTRACTED_AT to
ColumnType(SnowflakeDataType.TIMESTAMP_TZ.typeName, false),
SNOWFLAKE_AB_META to ColumnType(SnowflakeDataType.VARIANT.typeName, false),
SNOWFLAKE_AB_GENERATION_ID to ColumnType(SnowflakeDataType.NUMBER.typeName, true),
)
val rawModeMetaColNames: Set<String> = rawModeMetaColumns.keys
val schemaModeMetaColNames: Set<String> = schemaModeMetaColumns.keys
}
}
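A minimal sketch, not part of this change, of the ordering contract above: meta columns first, then user columns in the order of the final schema. The meta column names shown assume the standard Airbyte names uppercased for schema mode; "ID" and "NAME" are hypothetical user columns, and ColumnType comes from the imports in the file above.
fun sketchColumnOrder(metaColumnNames: Set<String>, finalSchema: Map<String, ColumnType>): List<String> =
    buildList {
        addAll(metaColumnNames) // e.g. _AIRBYTE_RAW_ID, _AIRBYTE_EXTRACTED_AT, _AIRBYTE_META, _AIRBYTE_GENERATION_ID
        addAll(finalSchema.keys) // user columns, already munged to Snowflake-compatible names
    }
// sketchColumnOrder(manager.getMetaColumnNames(), columnSchema.finalSchema)
// => [_AIRBYTE_RAW_ID, _AIRBYTE_EXTRACTED_AT, _AIRBYTE_META, _AIRBYTE_GENERATION_ID, ID, NAME]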

View File

@@ -0,0 +1,34 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.schema
import io.airbyte.cdk.ConfigErrorException
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
/**
* Transforms a string to be compatible with Snowflake table and column names.
*
* @return The transformed string suitable for Snowflake identifiers.
*/
fun String.toSnowflakeCompatibleName(): String {
var identifier = this
// Handle empty strings
if (identifier.isEmpty()) {
throw ConfigErrorException("Empty string is invalid identifier")
}
// The Snowflake scripting language mishandles the `${` bigram when it appears in a script, so
// replace those characters with underscores.
// For completeness, if this is triggered, also replace closing curly braces with underscores.
if (identifier.contains("\${")) {
identifier = identifier.replace("$", "_").replace("{", "_").replace("}", "_")
}
// Escape double quotes
identifier = escapeJsonIdentifier(identifier)
return identifier.uppercase()
}
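A brief usage sketch, not part of this change, assuming it lives in the same package as the extension above:
fun main() {
    println("users".toSnowflakeCompatibleName())       // USERS: plain identifiers are uppercased
    println("my\"col".toSnowflakeCompatibleName())      // MY""COL: embedded double quotes are doubled
    println("tmpl_\${id}".toSnowflakeCompatibleName())  // TMPL___ID_: the ${ bigram triggers underscore replacement
}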

View File

@@ -0,0 +1,121 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.schema
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.data.ArrayType
import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
import io.airbyte.cdk.load.data.BooleanType
import io.airbyte.cdk.load.data.DateType
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerType
import io.airbyte.cdk.load.data.NumberType
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.TimeTypeWithTimezone
import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
import io.airbyte.cdk.load.data.UnionType
import io.airbyte.cdk.load.data.UnknownType
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.schema.TableSchemaMapper
import io.airbyte.cdk.load.schema.model.StreamTableSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.TempTableNameGenerator
import io.airbyte.cdk.load.table.TypingDedupingUtil
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDataType
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
import jakarta.inject.Singleton
@Singleton
class SnowflakeTableSchemaMapper(
private val config: SnowflakeConfiguration,
private val tempTableNameGenerator: TempTableNameGenerator,
) : TableSchemaMapper {
override fun toFinalTableName(desc: DestinationStream.Descriptor): TableName {
val namespace = desc.namespace ?: config.schema
return if (!config.legacyRawTablesOnly) {
TableName(
namespace = namespace.toSnowflakeCompatibleName(),
name = desc.name.toSnowflakeCompatibleName(),
)
} else {
TableName(
namespace = config.internalTableSchema,
name =
TypingDedupingUtil.concatenateRawTableName(
namespace = escapeJsonIdentifier(namespace),
name = escapeJsonIdentifier(desc.name),
),
)
}
}
override fun toTempTableName(tableName: TableName): TableName {
return tempTableNameGenerator.generate(tableName)
}
override fun toColumnName(name: String): String {
return if (!config.legacyRawTablesOnly) {
name.toSnowflakeCompatibleName()
} else {
// In legacy mode, column names are not transformed
name
}
}
override fun toColumnType(fieldType: FieldType): ColumnType {
val snowflakeType =
when (fieldType.type) {
// Simple types
BooleanType -> SnowflakeDataType.BOOLEAN.typeName
IntegerType -> SnowflakeDataType.NUMBER.typeName
NumberType -> SnowflakeDataType.FLOAT.typeName
StringType -> SnowflakeDataType.VARCHAR.typeName
// Temporal types
DateType -> SnowflakeDataType.DATE.typeName
TimeTypeWithTimezone -> SnowflakeDataType.VARCHAR.typeName
TimeTypeWithoutTimezone -> SnowflakeDataType.TIME.typeName
TimestampTypeWithTimezone -> SnowflakeDataType.TIMESTAMP_TZ.typeName
TimestampTypeWithoutTimezone -> SnowflakeDataType.TIMESTAMP_NTZ.typeName
// Semistructured types
is ArrayType,
ArrayTypeWithoutSchema -> SnowflakeDataType.ARRAY.typeName
is ObjectType,
ObjectTypeWithEmptySchema,
ObjectTypeWithoutSchema -> SnowflakeDataType.OBJECT.typeName
is UnionType -> SnowflakeDataType.VARIANT.typeName
is UnknownType -> SnowflakeDataType.VARIANT.typeName
}
return ColumnType(snowflakeType, fieldType.nullable)
}
override fun toFinalSchema(tableSchema: StreamTableSchema): StreamTableSchema {
if (!config.legacyRawTablesOnly) {
return tableSchema
}
return StreamTableSchema(
tableNames = tableSchema.tableNames,
columnSchema =
tableSchema.columnSchema.copy(
finalSchema =
mapOf(
Meta.COLUMN_NAME_DATA to
ColumnType(SnowflakeDataType.OBJECT.typeName, false)
)
),
importType = tableSchema.importType,
)
}
}
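A small usage sketch, not part of this change, of the type mapping above; `mapper` is assumed to be a SnowflakeTableSchemaMapper wired with a real SnowflakeConfiguration and TempTableNameGenerator, reusing the CDK type imports from the file above.
fun sketchTypeMapping(mapper: SnowflakeTableSchemaMapper) {
    // Scalar types map to their Snowflake equivalents; NUMBER is now unparameterized.
    val id = mapper.toColumnType(FieldType(IntegerType, nullable = false))                // ColumnType("NUMBER", false)
    val name = mapper.toColumnType(FieldType(StringType, nullable = true))                // ColumnType("VARCHAR", true)
    // Semistructured types fall back to OBJECT; unions and unknowns become VARIANT.
    val props = mapper.toColumnType(FieldType(ObjectTypeWithoutSchema, nullable = true))  // ColumnType("OBJECT", true)
    println(listOf(id, name, props))
}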

View File

@@ -1,201 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.sql
import com.google.common.annotations.VisibleForTesting
import io.airbyte.cdk.load.data.AirbyteType
import io.airbyte.cdk.load.data.ArrayType
import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
import io.airbyte.cdk.load.data.BooleanType
import io.airbyte.cdk.load.data.DateType
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerType
import io.airbyte.cdk.load.data.NumberType
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.TimeTypeWithTimezone
import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
import io.airbyte.cdk.load.data.UnionType
import io.airbyte.cdk.load.data.UnknownType
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.integrations.destination.snowflake.db.SnowflakeColumnNameGenerator
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import jakarta.inject.Singleton
import kotlin.collections.component1
import kotlin.collections.component2
import kotlin.collections.joinToString
import kotlin.collections.map
import kotlin.collections.plus
internal const val NOT_NULL = "NOT NULL"
internal val DEFAULT_COLUMNS =
listOf(
ColumnAndType(
columnName = COLUMN_NAME_AB_RAW_ID,
columnType = "${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL"
),
ColumnAndType(
columnName = COLUMN_NAME_AB_EXTRACTED_AT,
columnType = "${SnowflakeDataType.TIMESTAMP_TZ.typeName} $NOT_NULL"
),
ColumnAndType(
columnName = COLUMN_NAME_AB_META,
columnType = "${SnowflakeDataType.VARIANT.typeName} $NOT_NULL"
),
ColumnAndType(
columnName = COLUMN_NAME_AB_GENERATION_ID,
columnType = SnowflakeDataType.NUMBER.typeName
),
)
internal val RAW_DATA_COLUMN =
ColumnAndType(
columnName = COLUMN_NAME_DATA,
columnType = "${SnowflakeDataType.VARIANT.typeName} $NOT_NULL"
)
internal val RAW_COLUMNS =
listOf(
ColumnAndType(
columnName = COLUMN_NAME_AB_LOADED_AT,
columnType = SnowflakeDataType.TIMESTAMP_TZ.typeName
),
RAW_DATA_COLUMN
)
@Singleton
class SnowflakeColumnUtils(
private val snowflakeConfiguration: SnowflakeConfiguration,
private val snowflakeColumnNameGenerator: SnowflakeColumnNameGenerator,
) {
@VisibleForTesting
internal fun defaultColumns(): List<ColumnAndType> =
if (snowflakeConfiguration.legacyRawTablesOnly) {
DEFAULT_COLUMNS + RAW_COLUMNS
} else {
DEFAULT_COLUMNS
}
internal fun formattedDefaultColumns(): List<ColumnAndType> =
defaultColumns().map {
ColumnAndType(
columnName = formatColumnName(it.columnName, false),
columnType = it.columnType,
)
}
fun getGenerationIdColumnName(): String {
return if (snowflakeConfiguration.legacyRawTablesOnly) {
COLUMN_NAME_AB_GENERATION_ID
} else {
COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()
}
}
fun getColumnNames(columnNameMapping: ColumnNameMapping): String =
if (snowflakeConfiguration.legacyRawTablesOnly) {
getFormattedDefaultColumnNames(true).joinToString(",")
} else {
(getFormattedDefaultColumnNames(true) +
columnNameMapping.map { (_, actualName) -> actualName.quote() })
.joinToString(",")
}
fun getFormattedDefaultColumnNames(quote: Boolean = false): List<String> =
defaultColumns().map { formatColumnName(it.columnName, quote) }
fun getFormattedColumnNames(
columns: Map<String, FieldType>,
columnNameMapping: ColumnNameMapping,
quote: Boolean = true,
): List<String> =
if (snowflakeConfiguration.legacyRawTablesOnly) {
getFormattedDefaultColumnNames(quote)
} else {
getFormattedDefaultColumnNames(quote) +
columns.map { (fieldName, _) ->
val columnName = columnNameMapping[fieldName] ?: fieldName
if (quote) columnName.quote() else columnName
}
}
fun columnsAndTypes(
columns: Map<String, FieldType>,
columnNameMapping: ColumnNameMapping
): List<ColumnAndType> =
if (snowflakeConfiguration.legacyRawTablesOnly) {
formattedDefaultColumns()
} else {
formattedDefaultColumns() +
columns.map { (fieldName, type) ->
val columnName = columnNameMapping[fieldName] ?: fieldName
val typeName = toDialectType(type.type)
ColumnAndType(
columnName = columnName,
columnType = if (type.nullable) typeName else "$typeName $NOT_NULL",
)
}
}
fun formatColumnName(
columnName: String,
quote: Boolean = true,
): String {
val formattedColumnName =
if (columnName == COLUMN_NAME_DATA) columnName
else snowflakeColumnNameGenerator.getColumnName(columnName).displayName
return if (quote) formattedColumnName.quote() else formattedColumnName
}
fun toDialectType(type: AirbyteType): String =
when (type) {
// Simple types
BooleanType -> SnowflakeDataType.BOOLEAN.typeName
IntegerType -> SnowflakeDataType.NUMBER.typeName
NumberType -> SnowflakeDataType.FLOAT.typeName
StringType -> SnowflakeDataType.VARCHAR.typeName
// Temporal types
DateType -> SnowflakeDataType.DATE.typeName
TimeTypeWithTimezone -> SnowflakeDataType.VARCHAR.typeName
TimeTypeWithoutTimezone -> SnowflakeDataType.TIME.typeName
TimestampTypeWithTimezone -> SnowflakeDataType.TIMESTAMP_TZ.typeName
TimestampTypeWithoutTimezone -> SnowflakeDataType.TIMESTAMP_NTZ.typeName
// Semistructured types
is ArrayType,
ArrayTypeWithoutSchema -> SnowflakeDataType.ARRAY.typeName
is ObjectType,
ObjectTypeWithEmptySchema,
ObjectTypeWithoutSchema -> SnowflakeDataType.OBJECT.typeName
is UnionType -> SnowflakeDataType.VARIANT.typeName
is UnknownType -> SnowflakeDataType.VARIANT.typeName
}
}
data class ColumnAndType(val columnName: String, val columnType: String) {
override fun toString(): String {
return "${columnName.quote()} $columnType"
}
}
/**
* Surrounds the string instance with double quotation marks (e.g. "some string" -> "\"some
* string\"").
*/
fun String.quote() = "$QUOTE$this$QUOTE"

View File

@@ -10,7 +10,7 @@ package io.airbyte.integrations.destination.snowflake.sql
*/
enum class SnowflakeDataType(val typeName: String) {
// Numeric types
NUMBER("NUMBER(38,0)"),
NUMBER("NUMBER"),
FLOAT("FLOAT"),
// String & binary types

View File

@@ -4,16 +4,19 @@
package io.airbyte.integrations.destination.snowflake.sql
import io.airbyte.cdk.load.command.Dedupe
import io.airbyte.cdk.load.command.DestinationStream
import com.google.common.annotations.VisibleForTesting
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.component.ColumnTypeChange
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.cdk.load.schema.model.StreamTableSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.util.UUIDGenerator
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.CdcDeletionMode
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.write.load.CSV_FIELD_SEPARATOR
@@ -22,6 +25,14 @@ import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton
internal const val COUNT_TOTAL_ALIAS = "TOTAL"
internal const val NOT_NULL = "NOT NULL"
// Snowflake-compatible (uppercase) versions of the Airbyte meta column names
internal val SNOWFLAKE_AB_RAW_ID = COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_EXTRACTED_AT = COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_META = COLUMN_NAME_AB_META.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_GENERATION_ID = COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN = CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()
private val log = KotlinLogging.logger {}
@@ -36,80 +47,91 @@ fun String.andLog(): String {
@Singleton
class SnowflakeDirectLoadSqlGenerator(
private val columnUtils: SnowflakeColumnUtils,
private val uuidGenerator: UUIDGenerator,
private val snowflakeConfiguration: SnowflakeConfiguration,
private val snowflakeSqlNameUtils: SnowflakeSqlNameUtils,
private val config: SnowflakeConfiguration,
private val columnManager: SnowflakeColumnManager,
) {
fun countTable(tableName: TableName): String {
return "SELECT COUNT(*) AS $COUNT_TOTAL_ALIAS FROM ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog()
return "SELECT COUNT(*) AS $COUNT_TOTAL_ALIAS FROM ${fullyQualifiedName(tableName)}".andLog()
}
fun createNamespace(namespace: String): String {
return "CREATE SCHEMA IF NOT EXISTS ${snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)}".andLog()
return "CREATE SCHEMA IF NOT EXISTS ${fullyQualifiedNamespace(namespace)}".andLog()
}
fun createTable(
stream: DestinationStream,
tableName: TableName,
columnNameMapping: ColumnNameMapping,
tableSchema: StreamTableSchema,
replace: Boolean
): String {
val finalSchema = tableSchema.columnSchema.finalSchema
val metaColumns = columnManager.getMetaColumns()
// Build column declarations from the meta columns and user schema
val columnDeclarations =
columnUtils
.columnsAndTypes(stream.schema.asColumns(), columnNameMapping)
.joinToString(",\n")
buildList {
// Add Airbyte meta columns from the column manager
metaColumns.forEach { (columnName, columnType) ->
val nullability = if (columnType.nullable) "" else " NOT NULL"
add("${columnName.quote()} ${columnType.type}$nullability")
}
// Add user columns from the munged schema
finalSchema.forEach { (columnName, columnType) ->
val nullability = if (columnType.nullable) "" else " NOT NULL"
add("${columnName.quote()} ${columnType.type}$nullability")
}
}
.joinToString(",\n ")
// Snowflake supports CREATE OR REPLACE TABLE, which is simpler than drop+recreate
val createOrReplace = if (replace) "CREATE OR REPLACE" else "CREATE"
val createTableStatement =
"""
$createOrReplace TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)} (
$columnDeclarations
)
""".trimIndent()
|$createOrReplace TABLE ${fullyQualifiedName(tableName)} (
| $columnDeclarations
|)
""".trimMargin() // Something was tripping up trimIndent so we opt for trimMargin
return createTableStatement.andLog()
}
fun showColumns(tableName: TableName): String =
"SHOW COLUMNS IN TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog()
"SHOW COLUMNS IN TABLE ${fullyQualifiedName(tableName)}".andLog()
fun copyTable(
columnNameMapping: ColumnNameMapping,
columnNames: Set<String>,
sourceTableName: TableName,
targetTableName: TableName
): String {
val columnNames = columnUtils.getColumnNames(columnNameMapping)
val columnList = columnNames.joinToString(", ") { it.quote() }
return """
INSERT INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)}
INSERT INTO ${fullyQualifiedName(targetTableName)}
(
$columnNames
$columnList
)
SELECT
$columnNames
FROM ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)}
$columnList
FROM ${fullyQualifiedName(sourceTableName)}
"""
.trimIndent()
.andLog()
}
fun upsertTable(
stream: DestinationStream,
columnNameMapping: ColumnNameMapping,
tableSchema: StreamTableSchema,
sourceTableName: TableName,
targetTableName: TableName
): String {
val importType = stream.importType as Dedupe
val finalSchema = tableSchema.columnSchema.finalSchema
// Build primary key matching condition
val pks = tableSchema.getPrimaryKey().flatten()
val pkEquivalent =
if (importType.primaryKey.isNotEmpty()) {
importType.primaryKey.joinToString(" AND ") { fieldPath ->
val fieldName = fieldPath.first()
val columnName = columnNameMapping[fieldName] ?: fieldName
if (pks.isNotEmpty()) {
pks.joinToString(" AND ") { columnName ->
val targetTableColumnName = "target_table.${columnName.quote()}"
val newRecordColumnName = "new_record.${columnName.quote()}"
"""($targetTableColumnName = $newRecordColumnName OR ($targetTableColumnName IS NULL AND $newRecordColumnName IS NULL))"""
@@ -120,80 +142,62 @@ class SnowflakeDirectLoadSqlGenerator(
}
// Build column lists for INSERT and UPDATE
val columnList: String =
columnUtils
.getFormattedColumnNames(
columns = stream.schema.asColumns(),
columnNameMapping = columnNameMapping,
quote = false,
)
.joinToString(
",\n",
) {
it.quote()
}
val allColumns = buildList {
add(SNOWFLAKE_AB_RAW_ID)
add(SNOWFLAKE_AB_EXTRACTED_AT)
add(SNOWFLAKE_AB_META)
add(SNOWFLAKE_AB_GENERATION_ID)
addAll(finalSchema.keys)
}
val columnList: String = allColumns.joinToString(",\n ") { it.quote() }
val newRecordColumnList: String =
columnUtils
.getFormattedColumnNames(
columns = stream.schema.asColumns(),
columnNameMapping = columnNameMapping,
quote = false,
)
.joinToString(",\n") { "new_record.${it.quote()}" }
allColumns.joinToString(",\n ") { "new_record.${it.quote()}" }
// Get deduped records from source
val selectSourceRecords = selectDedupedRecords(stream, sourceTableName, columnNameMapping)
val selectSourceRecords = selectDedupedRecords(tableSchema, sourceTableName)
// Build cursor comparison for determining which record is newer
val cursorComparison: String
if (importType.cursor.isNotEmpty()) {
val cursorFieldName = importType.cursor.first()
val cursor = (columnNameMapping[cursorFieldName] ?: cursorFieldName)
val cursor = tableSchema.getCursor().firstOrNull()
if (cursor != null) {
val targetTableCursor = "target_table.${cursor.quote()}"
val newRecordCursor = "new_record.${cursor.quote()}"
cursorComparison =
"""
(
$targetTableCursor < $newRecordCursor
OR ($targetTableCursor = $newRecordCursor AND target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}")
OR ($targetTableCursor IS NULL AND $newRecordCursor IS NULL AND target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}")
OR ($targetTableCursor = $newRecordCursor AND target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT")
OR ($targetTableCursor IS NULL AND $newRecordCursor IS NULL AND target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT")
OR ($targetTableCursor IS NULL AND $newRecordCursor IS $NOT_NULL)
)
""".trimIndent()
} else {
// No cursor - use extraction timestamp only
cursorComparison =
"""target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}""""
"""target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT""""
}
// Build column assignments for UPDATE
val columnAssignments: String =
columnUtils
.getFormattedColumnNames(
columns = stream.schema.asColumns(),
columnNameMapping = columnNameMapping,
quote = false,
)
.joinToString(",\n") { column ->
"${column.quote()} = new_record.${column.quote()}"
}
allColumns.joinToString(",\n ") { column ->
"${column.quote()} = new_record.${column.quote()}"
}
// Handle CDC deletions based on mode
val cdcDeleteClause: String
val cdcSkipInsertClause: String
if (
stream.schema.asColumns().containsKey(CDC_DELETED_AT_COLUMN) &&
snowflakeConfiguration.cdcDeletionMode == CdcDeletionMode.HARD_DELETE
finalSchema.containsKey(SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN) &&
config.cdcDeletionMode == CdcDeletionMode.HARD_DELETE
) {
// Execute CDC deletions if there's already a record
cdcDeleteClause =
"WHEN MATCHED AND new_record.\"${CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()}\" IS NOT NULL AND $cursorComparison THEN DELETE"
"WHEN MATCHED AND new_record.\"${SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN}\" IS NOT NULL AND $cursorComparison THEN DELETE"
// And skip insertion entirely if there's no matching record.
// (This is possible if a single T+D batch contains both an insertion and deletion for
// the same PK)
cdcSkipInsertClause =
"AND new_record.\"${CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()}\" IS NULL"
cdcSkipInsertClause = "AND new_record.\"${SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN}\" IS NULL"
} else {
cdcDeleteClause = ""
cdcSkipInsertClause = ""
@@ -203,35 +207,35 @@ class SnowflakeDirectLoadSqlGenerator(
val mergeStatement =
if (cdcDeleteClause.isNotEmpty()) {
"""
MERGE INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} AS target_table
USING (
$selectSourceRecords
) AS new_record
ON $pkEquivalent
$cdcDeleteClause
WHEN MATCHED AND $cursorComparison THEN UPDATE SET
$columnAssignments
WHEN NOT MATCHED $cdcSkipInsertClause THEN INSERT (
$columnList
) VALUES (
$newRecordColumnList
)
""".trimIndent()
|MERGE INTO ${fullyQualifiedName(targetTableName)} AS target_table
|USING (
|$selectSourceRecords
|) AS new_record
|ON $pkEquivalent
|$cdcDeleteClause
|WHEN MATCHED AND $cursorComparison THEN UPDATE SET
| $columnAssignments
|WHEN NOT MATCHED $cdcSkipInsertClause THEN INSERT (
| $columnList
|) VALUES (
| $newRecordColumnList
|)
""".trimMargin()
} else {
"""
MERGE INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} AS target_table
USING (
$selectSourceRecords
) AS new_record
ON $pkEquivalent
WHEN MATCHED AND $cursorComparison THEN UPDATE SET
$columnAssignments
WHEN NOT MATCHED THEN INSERT (
$columnList
) VALUES (
$newRecordColumnList
)
""".trimIndent()
|MERGE INTO ${fullyQualifiedName(targetTableName)} AS target_table
|USING (
|$selectSourceRecords
|) AS new_record
|ON $pkEquivalent
|WHEN MATCHED AND $cursorComparison THEN UPDATE SET
| $columnAssignments
|WHEN NOT MATCHED THEN INSERT (
| $columnList
|) VALUES (
| $newRecordColumnList
|)
""".trimMargin()
}
return mergeStatement.andLog()
@@ -242,75 +246,66 @@ class SnowflakeDirectLoadSqlGenerator(
* table. Uses ROW_NUMBER() window function to select the most recent record per primary key.
*/
private fun selectDedupedRecords(
stream: DestinationStream,
sourceTableName: TableName,
columnNameMapping: ColumnNameMapping
tableSchema: StreamTableSchema,
sourceTableName: TableName
): String {
val columnList: String =
columnUtils
.getFormattedColumnNames(
columns = stream.schema.asColumns(),
columnNameMapping = columnNameMapping,
quote = false,
)
.joinToString(
",\n",
) {
it.quote()
}
val importType = stream.importType as Dedupe
val allColumns = buildList {
add(SNOWFLAKE_AB_RAW_ID)
add(SNOWFLAKE_AB_EXTRACTED_AT)
add(SNOWFLAKE_AB_META)
add(SNOWFLAKE_AB_GENERATION_ID)
addAll(tableSchema.columnSchema.finalSchema.keys)
}
val columnList: String = allColumns.joinToString(",\n ") { it.quote() }
// Build the primary key list for partitioning
val pks = tableSchema.getPrimaryKey().flatten()
val pkList =
if (importType.primaryKey.isNotEmpty()) {
importType.primaryKey.joinToString(",") { fieldPath ->
(columnNameMapping[fieldPath.first()] ?: fieldPath.first()).quote()
}
if (pks.isNotEmpty()) {
pks.joinToString(",") { it.quote() }
} else {
// Should not happen as we check this earlier, but handle it defensively
throw IllegalArgumentException("Cannot deduplicate without primary key")
}
// Build cursor order clause for sorting within each partition
val cursor = tableSchema.getCursor().firstOrNull()
val cursorOrderClause =
if (importType.cursor.isNotEmpty()) {
val columnName =
(columnNameMapping[importType.cursor.first()] ?: importType.cursor.first())
.quote()
"$columnName DESC NULLS LAST,"
if (cursor != null) {
"${cursor.quote()} DESC NULLS LAST,"
} else {
""
}
return """
WITH records AS (
SELECT
$columnList
FROM ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)}
), numbered_rows AS (
SELECT *, ROW_NUMBER() OVER (
PARTITION BY $pkList ORDER BY $cursorOrderClause "${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" DESC
) AS row_number
FROM records
)
SELECT $columnList
FROM numbered_rows
WHERE row_number = 1
| WITH records AS (
| SELECT
| $columnList
| FROM ${fullyQualifiedName(sourceTableName)}
| ), numbered_rows AS (
| SELECT *, ROW_NUMBER() OVER (
| PARTITION BY $pkList ORDER BY $cursorOrderClause "$SNOWFLAKE_AB_EXTRACTED_AT" DESC
| ) AS row_number
| FROM records
| )
| SELECT $columnList
| FROM numbered_rows
| WHERE row_number = 1
"""
.trimIndent()
.trimMargin()
.andLog()
}
fun dropTable(tableName: TableName): String {
return "DROP TABLE IF EXISTS ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog()
return "DROP TABLE IF EXISTS ${fullyQualifiedName(tableName)}".andLog()
}
fun getGenerationId(
tableName: TableName,
): String {
return """
SELECT "${columnUtils.getGenerationIdColumnName()}"
FROM ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}
SELECT "${columnManager.getGenerationIdColumnName()}"
FROM ${fullyQualifiedName(tableName)}
LIMIT 1
"""
.trimIndent()
@@ -318,12 +313,12 @@ class SnowflakeDirectLoadSqlGenerator(
}
fun createSnowflakeStage(tableName: TableName): String {
val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName)
val stageName = fullyQualifiedStageName(tableName)
return "CREATE STAGE IF NOT EXISTS $stageName".andLog()
}
fun putInStage(tableName: TableName, tempFilePath: String): String {
val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true)
val stageName = fullyQualifiedStageName(tableName, true)
return """
PUT 'file://$tempFilePath' '@$stageName'
AUTO_COMPRESS = FALSE
@@ -334,35 +329,45 @@ class SnowflakeDirectLoadSqlGenerator(
.andLog()
}
fun copyFromStage(tableName: TableName, filename: String): String {
val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true)
fun copyFromStage(
tableName: TableName,
filename: String,
columnNames: List<String>? = null
): String {
val stageName = fullyQualifiedStageName(tableName, true)
val columnList =
columnNames?.let { names -> "(${names.joinToString(", ") { it.quote() }})" } ?: ""
return """
COPY INTO ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}
FROM '@$stageName'
FILE_FORMAT = (
TYPE = 'CSV'
COMPRESSION = GZIP
FIELD_DELIMITER = '$CSV_FIELD_SEPARATOR'
RECORD_DELIMITER = '$CSV_LINE_DELIMITER'
FIELD_OPTIONALLY_ENCLOSED_BY = '"'
TRIM_SPACE = TRUE
ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE
REPLACE_INVALID_CHARACTERS = TRUE
ESCAPE = NONE
ESCAPE_UNENCLOSED_FIELD = NONE
)
ON_ERROR = 'ABORT_STATEMENT'
PURGE = TRUE
files = ('$filename')
|COPY INTO ${fullyQualifiedName(tableName)}$columnList
|FROM '@$stageName'
|FILE_FORMAT = (
| TYPE = 'CSV'
| COMPRESSION = GZIP
| FIELD_DELIMITER = '$CSV_FIELD_SEPARATOR'
| RECORD_DELIMITER = '$CSV_LINE_DELIMITER'
| FIELD_OPTIONALLY_ENCLOSED_BY = '"'
| TRIM_SPACE = TRUE
| ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE
| REPLACE_INVALID_CHARACTERS = TRUE
| ESCAPE = NONE
| ESCAPE_UNENCLOSED_FIELD = NONE
|)
|ON_ERROR = 'ABORT_STATEMENT'
|PURGE = TRUE
|files = ('$filename')
"""
.trimIndent()
.trimMargin()
.andLog()
}
fun swapTableWith(sourceTableName: TableName, targetTableName: TableName): String {
return """
ALTER TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} SWAP WITH ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)}
ALTER TABLE ${fullyQualifiedName(sourceTableName)} SWAP WITH ${
fullyQualifiedName(
targetTableName,
)
}
"""
.trimIndent()
.andLog()
@@ -372,7 +377,11 @@ class SnowflakeDirectLoadSqlGenerator(
// Snowflake RENAME TO only accepts the table name, not a fully qualified name
// The renamed table stays in the same schema
return """
ALTER TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} RENAME TO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)}
ALTER TABLE ${fullyQualifiedName(sourceTableName)} RENAME TO ${
fullyQualifiedName(
targetTableName,
)
}
"""
.trimIndent()
.andLog()
@@ -382,7 +391,7 @@ class SnowflakeDirectLoadSqlGenerator(
schemaName: String,
tableName: String,
): String =
"""DESCRIBE TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(TableName(schemaName, tableName))}""".andLog()
"""DESCRIBE TABLE ${fullyQualifiedName(TableName(schemaName, tableName))}""".andLog()
fun alterTable(
tableName: TableName,
@@ -391,14 +400,14 @@ class SnowflakeDirectLoadSqlGenerator(
modifiedColumns: Map<String, ColumnTypeChange>,
): Set<String> {
val clauses = mutableSetOf<String>()
val prettyTableName = snowflakeSqlNameUtils.fullyQualifiedName(tableName)
val prettyTableName = fullyQualifiedName(tableName)
addedColumns.forEach { (name, columnType) ->
clauses.add(
// Note that we intentionally don't set NOT NULL.
// We're adding a new column, and we don't know what constitutes a reasonable
// default value for preexisting records.
// So we add the column as nullable.
"ALTER TABLE $prettyTableName ADD COLUMN ${name.quote()} ${columnType.type};".andLog()
"ALTER TABLE $prettyTableName ADD COLUMN ${name.quote()} ${columnType.type};".andLog(),
)
}
deletedColumns.forEach {
@@ -412,35 +421,34 @@ class SnowflakeDirectLoadSqlGenerator(
val tempColumn = "${name}_${uuidGenerator.v4()}"
clauses.add(
// As above: we add the column as nullable.
"ALTER TABLE $prettyTableName ADD COLUMN ${tempColumn.quote()} ${typeChange.newType.type};".andLog()
"ALTER TABLE $prettyTableName ADD COLUMN ${tempColumn.quote()} ${typeChange.newType.type};".andLog(),
)
clauses.add(
"UPDATE $prettyTableName SET ${tempColumn.quote()} = CAST(${name.quote()} AS ${typeChange.newType.type});".andLog()
"UPDATE $prettyTableName SET ${tempColumn.quote()} = CAST(${name.quote()} AS ${typeChange.newType.type});".andLog(),
)
val backupColumn = "${tempColumn}_backup"
clauses.add(
"""
ALTER TABLE $prettyTableName
RENAME COLUMN "$name" TO "$backupColumn";
""".trimIndent()
""".trimIndent(),
)
clauses.add(
"""
ALTER TABLE $prettyTableName
RENAME COLUMN "$tempColumn" TO "$name";
""".trimIndent()
""".trimIndent(),
)
clauses.add(
"ALTER TABLE $prettyTableName DROP COLUMN ${backupColumn.quote()};".andLog()
"ALTER TABLE $prettyTableName DROP COLUMN ${backupColumn.quote()};".andLog(),
)
} else if (!typeChange.originalType.nullable && typeChange.newType.nullable) {
// If the type is unchanged, we can change a column from NOT NULL to nullable.
// But we'll never do the reverse, because there's a decent chance that historical
// records
// had null values.
// records had null values.
// Users can always manually ALTER COLUMN ... SET NOT NULL if they want.
clauses.add(
"""ALTER TABLE $prettyTableName ALTER COLUMN "$name" DROP NOT NULL;""".andLog()
"""ALTER TABLE $prettyTableName ALTER COLUMN "$name" DROP NOT NULL;""".andLog(),
)
} else {
log.info {
@@ -450,4 +458,45 @@ class SnowflakeDirectLoadSqlGenerator(
}
return clauses
}
@VisibleForTesting
fun fullyQualifiedName(tableName: TableName): String =
combineParts(listOf(getDatabaseName(), tableName.namespace, tableName.name))
@VisibleForTesting
fun fullyQualifiedNamespace(namespace: String) =
combineParts(listOf(getDatabaseName(), namespace))
@VisibleForTesting
fun fullyQualifiedStageName(tableName: TableName, escape: Boolean = false): String {
val currentTableName =
if (escape) {
tableName.name
} else {
tableName.name
}
return combineParts(
parts =
listOf(
getDatabaseName(),
tableName.namespace,
"$STAGE_NAME_PREFIX$currentTableName",
),
escape = escape,
)
}
@VisibleForTesting
internal fun combineParts(parts: List<String>, escape: Boolean = false): String =
parts
.map { if (escape) sqlEscape(it) else it }
.joinToString(separator = ".") {
if (!it.startsWith(QUOTE)) {
"$QUOTE$it$QUOTE"
} else {
it
}
}
private fun getDatabaseName() = config.database.toSnowflakeCompatibleName()
}
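A minimal sketch, not part of this change, of the name-qualification helpers above; `generator` is assumed to be built with a configuration whose database is the hypothetical "airbyte_db".
fun sketchQualifiedNames(generator: SnowflakeDirectLoadSqlGenerator) {
    val table = TableName(namespace = "PUBLIC", name = "USERS")
    // combineParts double-quotes each part (unless already quoted) and joins with dots.
    println(generator.fullyQualifiedName(table))          // "AIRBYTE_DB"."PUBLIC"."USERS"
    println(generator.fullyQualifiedStageName(table))     // "AIRBYTE_DB"."PUBLIC"."airbyte_stage_USERS"
    println(generator.fullyQualifiedNamespace("PUBLIC"))  // "AIRBYTE_DB"."PUBLIC"
}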

View File

@@ -0,0 +1,29 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.sql
const val STAGE_NAME_PREFIX = "airbyte_stage_"
internal const val QUOTE: String = "\""
fun sqlEscape(part: String) = part.replace("\\", "\\\\").replace("'", "\\'").replace("\"", "\\\"")
/**
* Surrounds the string instance with double quotation marks (e.g. "some string" -> "\"some
* string\"").
*/
fun String.quote() = "$QUOTE$this$QUOTE"
/**
* Escapes double-quotes in a JSON identifier by doubling them. This is legacy behavior; the
* original motivation is unclear, but it is harmless, so it is kept.
*
* @return The escaped identifier.
*/
fun escapeJsonIdentifier(identifier: String): String {
// Note that we don't need to escape backslashes here!
// The only special character in an identifier is the double-quote, which needs to be
// doubled.
return identifier.replace(QUOTE, "$QUOTE$QUOTE")
}
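A quick usage sketch, not part of this change, for the helpers above:
fun main() {
    println("final_table".quote())         // "final_table"
    println(escapeJsonIdentifier("a\"b"))  // a""b
    println(sqlEscape("it's \"ok\""))      // it\'s \"ok\"
}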

View File

@@ -1,57 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.sql
import io.airbyte.cdk.load.table.TableName
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import jakarta.inject.Singleton
const val STAGE_NAME_PREFIX = "airbyte_stage_"
internal const val QUOTE: String = "\""
fun sqlEscape(part: String) = part.replace("\\", "\\\\").replace("'", "\\'").replace("\"", "\\\"")
@Singleton
class SnowflakeSqlNameUtils(
private val snowflakeConfiguration: SnowflakeConfiguration,
) {
fun fullyQualifiedName(tableName: TableName): String =
combineParts(listOf(getDatabaseName(), tableName.namespace, tableName.name))
fun fullyQualifiedNamespace(namespace: String) =
combineParts(listOf(getDatabaseName(), namespace))
fun fullyQualifiedStageName(tableName: TableName, escape: Boolean = false): String {
val currentTableName =
if (escape) {
tableName.name
} else {
tableName.name
}
return combineParts(
parts =
listOf(
getDatabaseName(),
tableName.namespace,
"$STAGE_NAME_PREFIX$currentTableName"
),
escape = escape,
)
}
fun combineParts(parts: List<String>, escape: Boolean = false): String =
parts
.map { if (escape) sqlEscape(it) else it }
.joinToString(separator = ".") {
if (!it.startsWith(QUOTE)) {
"$QUOTE$it$QUOTE"
} else {
it
}
}
private fun getDatabaseName() = snowflakeConfiguration.database.toSnowflakeCompatibleName()
}

View File

@@ -6,38 +6,39 @@ package io.airbyte.integrations.destination.snowflake.write
import io.airbyte.cdk.SystemErrorException
import io.airbyte.cdk.load.command.Dedupe
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.orchestration.db.DatabaseInitialStatusGatherer
import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadInitialStatus
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendStreamLoader
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendTruncateStreamLoader
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableDedupStreamLoader
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableDedupTruncateStreamLoader
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.DatabaseInitialStatusGatherer
import io.airbyte.cdk.load.table.TempTableNameGenerator
import io.airbyte.cdk.load.table.directload.DirectLoadInitialStatus
import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendStreamLoader
import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendTruncateStreamLoader
import io.airbyte.cdk.load.table.directload.DirectLoadTableDedupStreamLoader
import io.airbyte.cdk.load.table.directload.DirectLoadTableDedupTruncateStreamLoader
import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.write.DestinationWriter
import io.airbyte.cdk.load.write.StreamLoader
import io.airbyte.cdk.load.write.StreamStateStore
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
import jakarta.inject.Singleton
@Singleton
class SnowflakeWriter(
private val names: TableCatalog,
private val catalog: DestinationCatalog,
private val stateGatherer: DatabaseInitialStatusGatherer<DirectLoadInitialStatus>,
private val streamStateStore: StreamStateStore<DirectLoadTableExecutionConfig>,
private val snowflakeClient: SnowflakeAirbyteClient,
private val tempTableNameGenerator: TempTableNameGenerator,
private val snowflakeConfiguration: SnowflakeConfiguration,
private val tempTableNameGenerator: TempTableNameGenerator,
) : DestinationWriter {
private lateinit var initialStatuses: Map<DestinationStream, DirectLoadInitialStatus>
override suspend fun setup() {
names.values
.map { (tableNames, _) -> tableNames.finalTableName!!.namespace }
catalog.streams
.map { it.tableSchema.tableNames.finalTableName!!.namespace }
.toSet()
.forEach { snowflakeClient.createNamespace(it) }
@@ -45,15 +46,15 @@ class SnowflakeWriter(
escapeJsonIdentifier(snowflakeConfiguration.internalTableSchema)
)
initialStatuses = stateGatherer.gatherInitialStatus(names)
initialStatuses = stateGatherer.gatherInitialStatus()
}
override fun createStreamLoader(stream: DestinationStream): StreamLoader {
val initialStatus = initialStatuses[stream]!!
val tableNameInfo = names[stream]!!
val realTableName = tableNameInfo.tableNames.finalTableName!!
val tempTableName = tempTableNameGenerator.generate(realTableName)
val columnNameMapping = tableNameInfo.columnNameMapping
val realTableName = stream.tableSchema.tableNames.finalTableName!!
val tempTableName = stream.tableSchema.tableNames.tempTableName!!
val columnNameMapping =
ColumnNameMapping(stream.tableSchema.columnSchema.inputToFinalColumnNames)
return when (stream.minimumGenerationId) {
0L ->
when (stream.importType) {

View File

@@ -9,11 +9,12 @@ import de.siegmar.fastcsv.writer.CsvWriter
import de.siegmar.fastcsv.writer.LineDelimiter
import de.siegmar.fastcsv.writer.QuoteStrategies
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.QUOTE
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.github.oshai.kotlinlogging.KotlinLogging
import java.io.File
import java.io.OutputStream
@@ -36,10 +37,11 @@ private const val CSV_WRITER_BUFFER_SIZE = 1024 * 1024 // 1 MB
class SnowflakeInsertBuffer(
private val tableName: TableName,
val columns: LinkedHashMap<String, String>,
private val snowflakeClient: SnowflakeAirbyteClient,
val snowflakeConfiguration: SnowflakeConfiguration,
val snowflakeColumnUtils: SnowflakeColumnUtils,
val columnSchema: ColumnSchema,
private val columnManager: SnowflakeColumnManager,
private val snowflakeRecordFormatter: SnowflakeRecordFormatter,
private val flushLimit: Int = DEFAULT_FLUSH_LIMIT,
) {
@@ -57,12 +59,6 @@ class SnowflakeInsertBuffer(
.lineDelimiter(CSV_LINE_DELIMITER)
.quoteStrategy(QuoteStrategies.REQUIRED)
private val snowflakeRecordFormatter: SnowflakeRecordFormatter =
when (snowflakeConfiguration.legacyRawTablesOnly) {
true -> SnowflakeRawRecordFormatter(columns, snowflakeColumnUtils)
else -> SnowflakeSchemaRecordFormatter(columns, snowflakeColumnUtils)
}
fun accumulate(recordFields: Map<String, AirbyteValue>) {
if (csvFilePath == null) {
val csvFile = createCsvFile()
@@ -92,7 +88,9 @@ class SnowflakeInsertBuffer(
"Copying staging data into ${tableName.toPrettyString(quote = QUOTE)}..."
}
// Finally, copy the data from the staging table to the final table
snowflakeClient.copyFromStage(tableName, filePath.fileName.toString())
// Pass column names to ensure correct mapping even after ALTER TABLE operations
val columnNames = columnManager.getTableColumnNames(columnSchema)
snowflakeClient.copyFromStage(tableName, filePath.fileName.toString(), columnNames)
logger.info {
"Finished insert of $recordCount row(s) into ${tableName.toPrettyString(quote = QUOTE)}."
}
@@ -117,7 +115,9 @@ class SnowflakeInsertBuffer(
private fun writeToCsvFile(record: Map<String, AirbyteValue>) {
csvWriter?.let {
it.writeRecord(snowflakeRecordFormatter.format(record).map { col -> col.toString() })
it.writeRecord(
snowflakeRecordFormatter.format(record, columnSchema).map { col -> col.toString() }
)
recordCount++
if ((recordCount % flushLimit) == 0) {
it.flush()
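Illustrative sketch (not part of the diff above): the COPY path now passes the table's expected column order explicitly, so the generated COPY INTO statement can name its target columns instead of relying on the table's positional column order. Only getTableColumnNames and copyFromStage below come from the hunk above; the wrapper function and its parameters are assumptions for illustration, and types reuse this file's imports.
// Hypothetical caller mirroring the flush logic above.
private suspend fun copyStagedFile(
    snowflakeClient: SnowflakeAirbyteClient,
    columnManager: SnowflakeColumnManager,
    tableName: TableName,
    columnSchema: ColumnSchema,
    stagedFileName: String,
) {
    // Meta columns first, then user columns, in the same order the CSV rows were written.
    val columnNames = columnManager.getTableColumnNames(columnSchema)
    // Conceptually: COPY INTO <table> (<columnNames...>) FROM @<stage>/<stagedFileName> ...
    snowflakeClient.copyFromStage(tableName, stagedFileName, columnNames)
}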

View File

@@ -8,99 +8,62 @@ import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.data.csv.toCsvValue
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.cdk.load.util.Jsons
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
interface SnowflakeRecordFormatter {
fun format(record: Map<String, AirbyteValue>): List<Any>
fun format(record: Map<String, AirbyteValue>, columnSchema: ColumnSchema): List<Any>
}
class SnowflakeSchemaRecordFormatter(
private val columns: LinkedHashMap<String, String>,
val snowflakeColumnUtils: SnowflakeColumnUtils,
) : SnowflakeRecordFormatter {
class SnowflakeSchemaRecordFormatter : SnowflakeRecordFormatter {
override fun format(record: Map<String, AirbyteValue>, columnSchema: ColumnSchema): List<Any> {
val result = mutableListOf<Any>()
val userColumns = columnSchema.finalSchema.keys
private val airbyteColumnNames =
snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet()
// WARNING: MUST match the order defined in SnowflakeColumnManager#getTableColumnNames
//
// Why don't we just use that here? Well, unlike the user fields, the meta fields on the
// record are not munged for the destination. So we must access the values for those columns
// using the original lowercase meta key.
result.add(record[COLUMN_NAME_AB_RAW_ID].toCsvValue())
result.add(record[COLUMN_NAME_AB_EXTRACTED_AT].toCsvValue())
result.add(record[COLUMN_NAME_AB_META].toCsvValue())
result.add(record[COLUMN_NAME_AB_GENERATION_ID].toCsvValue())
override fun format(record: Map<String, AirbyteValue>): List<Any> =
columns.map { (columnName, _) ->
/*
* Meta columns are forced to uppercase for backwards compatibility with previous
* versions of the destination. Therefore, convert the column to lowercase so
* that it can match the constants, which use the lowercase version of the meta
* column names.
*/
if (airbyteColumnNames.contains(columnName)) {
record[columnName.lowercase()].toCsvValue()
} else {
record.keys
// The columns retrieved from Snowflake do not have any escaping applied.
// Therefore, re-apply the compatible name escaping to the name of the
// columns retrieved from Snowflake. The record keys should already have
// been escaped by the CDK before arriving at the aggregate, so no need
// to escape again here.
.find { it == columnName.toSnowflakeCompatibleName() }
?.let { record[it].toCsvValue() }
?: ""
}
}
// Add user columns from the final schema
userColumns.forEach { columnName -> result.add(record[columnName].toCsvValue()) }
return result
}
}
class SnowflakeRawRecordFormatter(
columns: LinkedHashMap<String, String>,
val snowflakeColumnUtils: SnowflakeColumnUtils,
) : SnowflakeRecordFormatter {
private val columns = columns.keys
class SnowflakeRawRecordFormatter : SnowflakeRecordFormatter {
private val airbyteColumnNames =
snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet()
override fun format(record: Map<String, AirbyteValue>): List<Any> =
override fun format(record: Map<String, AirbyteValue>, columnSchema: ColumnSchema): List<Any> =
toOutputRecord(record.toMutableMap())
private fun toOutputRecord(record: MutableMap<String, AirbyteValue>): List<Any> {
val outputRecord = mutableListOf<Any>()
// Copy the Airbyte metadata columns to the raw output, removing each
// one from the record to avoid duplicates in the "data" field
columns
.filter { airbyteColumnNames.contains(it) && it != Meta.COLUMN_NAME_DATA }
.forEach { column -> safeAddToOutput(column, record, outputRecord) }
val mutableRecord = record.toMutableMap()
// Add meta columns in order (except _airbyte_data which we handle specially)
outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_RAW_ID)?.toCsvValue() ?: "")
outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_EXTRACTED_AT)?.toCsvValue() ?: "")
outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_META)?.toCsvValue() ?: "")
outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_GENERATION_ID)?.toCsvValue() ?: "")
outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_LOADED_AT)?.toCsvValue() ?: "")
// Do not output null values in the JSON raw output
val filteredRecord = record.filter { (_, v) -> v !is NullValue }
// Convert all the remaining columns in the record to a JSON document stored in the "data"
// column. Add it in the same position as the _airbyte_data column in the column list to
// ensure it is inserted into the proper column in the table.
insert(
columns.indexOf(Meta.COLUMN_NAME_DATA),
StringValue(Jsons.writeValueAsString(filteredRecord)).toCsvValue(),
outputRecord
)
val filteredRecord = mutableRecord.filter { (_, v) -> v !is NullValue }
// Convert all the remaining columns to a JSON document stored in the "data" column
outputRecord.add(StringValue(Jsons.writeValueAsString(filteredRecord)).toCsvValue())
return outputRecord
}
private fun safeAddToOutput(
key: String,
record: MutableMap<String, AirbyteValue>,
output: MutableList<Any>
) {
val extractedValue = record.remove(key)
// Ensure that the data is inserted into the list at the same position as the column
insert(columns.indexOf(key), extractedValue?.toCsvValue() ?: "", output)
}
private fun insert(index: Int, value: Any, list: MutableList<Any>) {
/*
* Attempt to insert the value into the proper order in the list. If the index
* is already present in the list, use the add(index, element) method to insert it
* into the proper order and push everything to the right. If the index is at the
* end of the list, just use add(element) to insert it at the end. If the index
* is further beyond the end of the list, throw an exception as that should not occur.
*/
if (index < list.size) list.add(index, value)
else if (index == list.size || index == list.size + 1) list.add(value)
else throw IndexOutOfBoundsException()
}
}
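Illustrative sketch (not part of the diff above): the new formatter contract is that the four Airbyte meta columns come first, in the fixed order raw_id, extracted_at, meta, generation_id, followed by the user columns in the ColumnSchema's finalSchema iteration order. The ColumnSchema construction below mirrors the test helper later in this diff; every literal value is made up for illustration, and constants and value types reuse this file's imports.
// Hypothetical data; ColumnType is io.airbyte.cdk.load.component.ColumnType as in the component tests.
fun formatterOrderingSketch(): List<Any> {
    val columnSchema =
        io.airbyte.cdk.load.schema.model.ColumnSchema(
            inputToFinalColumnNames = mapOf("id" to "ID", "name" to "NAME"),
            finalSchema =
                linkedMapOf(
                    "ID" to io.airbyte.cdk.load.component.ColumnType("NUMBER", true),
                    "NAME" to io.airbyte.cdk.load.component.ColumnType("VARCHAR", true),
                ),
            inputSchema = emptyMap(),
        )
    val record: Map<String, AirbyteValue> =
        mapOf(
            COLUMN_NAME_AB_RAW_ID to StringValue("uuid-1"),
            COLUMN_NAME_AB_EXTRACTED_AT to StringValue("2025-01-01T00:00:00Z"),
            COLUMN_NAME_AB_META to StringValue("{}"),
            COLUMN_NAME_AB_GENERATION_ID to StringValue("1"),
            "ID" to StringValue("42"),
            "NAME" to StringValue("ada"),
        )
    // Result order: raw_id, extracted_at, meta, generation_id, then ID and NAME.
    return SnowflakeSchemaRecordFormatter().format(record, columnSchema)
}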

View File

@@ -1,25 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.write.transform
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.dataflow.transform.ColumnNameMapper
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import jakarta.inject.Singleton
@Singleton
class SnowflakeColumnNameMapper(
private val catalogInfo: TableCatalog,
private val snowflakeConfiguration: SnowflakeConfiguration,
) : ColumnNameMapper {
override fun getMappedColumnName(stream: DestinationStream, columnName: String): String {
if (snowflakeConfiguration.legacyRawTablesOnly == true) {
return columnName
} else {
return catalogInfo.getMappedColumnName(stream, columnName)!!
}
}
}

View File

@@ -7,9 +7,11 @@ package io.airbyte.integrations.destination.snowflake.component
import io.airbyte.cdk.load.component.TableOperationsFixtures
import io.airbyte.cdk.load.component.TableOperationsSuite
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.idTestWithCdcMapping
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.testMapping
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.idTestWithCdcMapping
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.testMapping
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeTestTableOperationsClient
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.parallel.Execution
@@ -20,6 +22,7 @@ import org.junit.jupiter.api.parallel.ExecutionMode
class SnowflakeTableOperationsTest(
override val client: SnowflakeAirbyteClient,
override val testClient: SnowflakeTestTableOperationsClient,
override val schemaFactory: TableSchemaFactory,
) : TableOperationsSuite {
override val airbyteMetaColumnMapping = Meta.COLUMN_NAMES.associateWith { it.uppercase() }

View File

@@ -9,13 +9,15 @@ import io.airbyte.cdk.load.component.TableSchemaEvolutionFixtures
import io.airbyte.cdk.load.component.TableSchemaEvolutionSuite
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.util.serializeToString
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.allTypesColumnNameMapping
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.allTypesTableSchema
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.idAndTestMapping
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.testMapping
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.allTypesColumnNameMapping
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.allTypesTableSchema
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.idAndTestMapping
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.testMapping
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeTestTableOperationsClient
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.parallel.Execution
@@ -27,6 +29,7 @@ class SnowflakeTableSchemaEvolutionTest(
override val client: SnowflakeAirbyteClient,
override val opsClient: SnowflakeAirbyteClient,
override val testClient: SnowflakeTestTableOperationsClient,
override val schemaFactory: TableSchemaFactory,
) : TableSchemaEvolutionSuite {
override val airbyteMetaColumnMapping = Meta.COLUMN_NAMES.associateWith { it.uppercase() }

View File

@@ -2,7 +2,7 @@
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.component
package io.airbyte.integrations.destination.snowflake.component.config
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.component.TableOperationsFixtures
@@ -33,6 +33,8 @@ object SnowflakeComponentTestFixtures {
"TIME_NTZ" to ColumnType("TIME", true),
"ARRAY" to ColumnType("ARRAY", true),
"OBJECT" to ColumnType("OBJECT", true),
"UNION" to ColumnType("VARIANT", true),
"LEGACY_UNION" to ColumnType("VARIANT", true),
"UNKNOWN" to ColumnType("VARIANT", true),
)
)

View File

@@ -2,7 +2,7 @@
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.component
package io.airbyte.integrations.destination.snowflake.component.config
import io.airbyte.cdk.load.component.config.TestConfigLoader.loadTestConfig
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration

View File

@@ -2,22 +2,24 @@
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.component
package io.airbyte.integrations.destination.snowflake.component.config
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.component.TestTableOperationsClient
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.dataflow.state.PartitionKey
import io.airbyte.cdk.load.dataflow.transform.RecordDTO
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.util.Jsons
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.client.execute
import io.airbyte.integrations.destination.snowflake.dataflow.SnowflakeAggregate
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
import io.airbyte.integrations.destination.snowflake.sql.andLog
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
import io.micronaut.context.annotation.Requires
import jakarta.inject.Singleton
import java.time.format.DateTimeFormatter
@@ -29,25 +31,40 @@ import net.snowflake.client.jdbc.SnowflakeTimestampWithTimezone
class SnowflakeTestTableOperationsClient(
private val client: SnowflakeAirbyteClient,
private val dataSource: DataSource,
private val snowflakeSqlNameUtils: SnowflakeSqlNameUtils,
private val snowflakeColumnUtils: SnowflakeColumnUtils,
private val sqlGenerator: SnowflakeDirectLoadSqlGenerator,
private val snowflakeConfiguration: SnowflakeConfiguration,
private val columnManager: SnowflakeColumnManager,
private val snowflakeRecordFormatter: SnowflakeRecordFormatter,
) : TestTableOperationsClient {
override suspend fun dropNamespace(namespace: String) {
dataSource.execute(
"DROP SCHEMA IF EXISTS ${snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)}".andLog()
"DROP SCHEMA IF EXISTS ${sqlGenerator.fullyQualifiedNamespace(namespace)}".andLog()
)
}
override suspend fun insertRecords(table: TableName, records: List<Map<String, AirbyteValue>>) {
// TODO: we should just pass a proper column schema
// Since we don't pass in a proper column schema, we have to recreate one here
// Fetch the columns and filter out the meta columns so we're just looking at user columns
val columnTypes =
client.describeTable(table).filterNot {
columnManager.getMetaColumnNames().contains(it.key)
}
val columnSchema =
io.airbyte.cdk.load.schema.model.ColumnSchema(
inputToFinalColumnNames = columnTypes.keys.associateWith { it },
finalSchema = columnTypes.mapValues { (_, _) -> ColumnType("", true) },
inputSchema = emptyMap() // Not needed for insert buffer
)
val a =
SnowflakeAggregate(
SnowflakeInsertBuffer(
table,
client.describeTable(table),
client,
snowflakeConfiguration,
snowflakeColumnUtils,
tableName = table,
snowflakeClient = client,
snowflakeConfiguration = snowflakeConfiguration,
columnSchema = columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
)
)
records.forEach { a.accept(RecordDTO(it, PartitionKey(""), 0, 0)) }

View File

@@ -7,11 +7,11 @@ package io.airbyte.integrations.destination.snowflake.write
import io.airbyte.cdk.load.test.util.DestinationCleaner
import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory
import io.airbyte.integrations.destination.snowflake.cdk.SnowflakeMigratingConfigurationSpecificationSupplier
import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfigurationFactory
import io.airbyte.integrations.destination.snowflake.sql.STAGE_NAME_PREFIX
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
import io.airbyte.integrations.destination.snowflake.sql.quote
import java.nio.file.Files
import java.sql.Connection

View File

@@ -12,16 +12,22 @@ import io.airbyte.cdk.load.data.ObjectValue
import io.airbyte.cdk.load.data.json.toAirbyteValue
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN
import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator
import io.airbyte.cdk.load.test.util.DestinationDataDumper
import io.airbyte.cdk.load.test.util.OutputRecord
import io.airbyte.cdk.load.util.UUIDGenerator
import io.airbyte.cdk.load.util.deserializeToNode
import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory
import io.airbyte.integrations.destination.snowflake.db.SnowflakeFinalTableNameGenerator
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeTableSchemaMapper
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
import io.airbyte.integrations.destination.snowflake.sql.sqlEscape
import java.math.BigDecimal
import java.sql.Date
import java.sql.Time
import java.sql.Timestamp
import net.snowflake.client.jdbc.SnowflakeTimestampWithTimezone
private val AIRBYTE_META_COLUMNS = Meta.COLUMN_NAMES + setOf(CDC_DELETED_AT_COLUMN)
@@ -34,8 +40,14 @@ class SnowflakeDataDumper(
stream: DestinationStream
): List<OutputRecord> {
val config = configProvider(spec)
val sqlUtils = SnowflakeSqlNameUtils(config)
val snowflakeFinalTableNameGenerator = SnowflakeFinalTableNameGenerator(config)
val snowflakeFinalTableNameGenerator =
SnowflakeTableSchemaMapper(
config = config,
tempTableNameGenerator = DefaultTempTableNameGenerator(),
)
val snowflakeColumnManager = SnowflakeColumnManager(config)
val sqlGenerator =
SnowflakeDirectLoadSqlGenerator(UUIDGenerator(), config, snowflakeColumnManager)
val dataSource =
SnowflakeBeanFactory()
.snowflakeDataSource(snowflakeConfiguration = config, airbyteEdition = "COMMUNITY")
@@ -46,7 +58,7 @@ class SnowflakeDataDumper(
ds.connection.use { connection ->
val statement = connection.createStatement()
val tableName =
snowflakeFinalTableNameGenerator.getTableName(stream.mappedDescriptor)
snowflakeFinalTableNameGenerator.toFinalTableName(stream.mappedDescriptor)
// First check if the table exists
val tableExistsQuery =
@@ -69,7 +81,7 @@ class SnowflakeDataDumper(
val resultSet =
statement.executeQuery(
"SELECT * FROM ${sqlUtils.fullyQualifiedName(tableName)}"
"SELECT * FROM ${sqlGenerator.fullyQualifiedName(tableName)}"
)
while (resultSet.next()) {
@@ -143,10 +155,10 @@ class SnowflakeDataDumper(
private fun convertValue(value: Any?): Any? =
when (value) {
is BigDecimal -> value.toBigInteger()
is java.sql.Date -> value.toLocalDate()
is Date -> value.toLocalDate()
is SnowflakeTimestampWithTimezone -> value.toZonedDateTime()
is java.sql.Time -> value.toLocalTime()
is java.sql.Timestamp -> value.toLocalDateTime()
is Time -> value.toLocalTime()
is Timestamp -> value.toLocalDateTime()
else -> value
}
}

View File

@@ -15,7 +15,7 @@ import io.airbyte.cdk.load.dataflow.transform.ValidationResult
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.test.util.ExpectedRecordMapper
import io.airbyte.cdk.load.test.util.OutputRecord
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.write.transform.SnowflakeValueCoercer
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Change

View File

@@ -5,7 +5,7 @@
package io.airbyte.integrations.destination.snowflake.write
import io.airbyte.cdk.load.test.util.NameMapper
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
class SnowflakeNameMapper : NameMapper {
override fun mapFieldName(path: List<String>): List<String> =

View File

@@ -9,13 +9,16 @@ import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.data.json.toAirbyteValue
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator
import io.airbyte.cdk.load.test.util.DestinationDataDumper
import io.airbyte.cdk.load.test.util.OutputRecord
import io.airbyte.cdk.load.util.UUIDGenerator
import io.airbyte.cdk.load.util.deserializeToNode
import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory
import io.airbyte.integrations.destination.snowflake.db.SnowflakeFinalTableNameGenerator
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeTableSchemaMapper
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
class SnowflakeRawDataDumper(
private val configProvider: (ConfigurationSpecification) -> SnowflakeConfiguration
@@ -27,8 +30,18 @@ class SnowflakeRawDataDumper(
val output = mutableListOf<OutputRecord>()
val config = configProvider(spec)
val sqlUtils = SnowflakeSqlNameUtils(config)
val snowflakeFinalTableNameGenerator = SnowflakeFinalTableNameGenerator(config)
val snowflakeColumnManager = SnowflakeColumnManager(config)
val sqlGenerator =
SnowflakeDirectLoadSqlGenerator(
UUIDGenerator(),
config,
snowflakeColumnManager,
)
val snowflakeFinalTableNameGenerator =
SnowflakeTableSchemaMapper(
config = config,
tempTableNameGenerator = DefaultTempTableNameGenerator(),
)
val dataSource =
SnowflakeBeanFactory()
.snowflakeDataSource(snowflakeConfiguration = config, airbyteEdition = "COMMUNITY")
@@ -37,11 +50,11 @@ class SnowflakeRawDataDumper(
ds.connection.use { connection ->
val statement = connection.createStatement()
val tableName =
snowflakeFinalTableNameGenerator.getTableName(stream.mappedDescriptor)
snowflakeFinalTableNameGenerator.toFinalTableName(stream.mappedDescriptor)
val resultSet =
statement.executeQuery(
"SELECT * FROM ${sqlUtils.fullyQualifiedName(tableName)}"
"SELECT * FROM ${sqlGenerator.fullyQualifiedName(tableName)}"
)
while (resultSet.next()) {

View File

@@ -6,7 +6,7 @@ package io.airbyte.integrations.destination.snowflake
import com.zaxxer.hikari.HikariConfig
import com.zaxxer.hikari.HikariDataSource
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.CdcDeletionMode
import io.airbyte.integrations.destination.snowflake.spec.KeyPairAuthConfiguration
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration

View File

@@ -5,11 +5,9 @@
package io.airbyte.integrations.destination.snowflake.check
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.DEFAULT_COLUMNS
import io.airbyte.integrations.destination.snowflake.sql.RAW_DATA_COLUMN
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
@@ -23,47 +21,31 @@ internal class SnowflakeCheckerTest {
@ParameterizedTest
@ValueSource(booleans = [true, false])
fun testSuccessfulCheck(isLegacyRawTablesOnly: Boolean) {
val defaultColumnsMap =
if (isLegacyRawTablesOnly) {
linkedMapOf<String, String>().also { map ->
(DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
map[it.columnName] = it.columnType
}
}
} else {
linkedMapOf<String, String>().also { map ->
(DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
map[it.columnName.toSnowflakeCompatibleName()] = it.columnType
}
}
}
val defaultColumns = defaultColumnsMap.keys.toMutableList()
val snowflakeAirbyteClient: SnowflakeAirbyteClient =
mockk(relaxed = true) {
coEvery { countTable(any()) } returns 1L
coEvery { describeTable(any()) } returns defaultColumnsMap
}
mockk(relaxed = true) { coEvery { countTable(any()) } returns 1L }
val testSchema = "test-schema"
val snowflakeConfiguration: SnowflakeConfiguration = mockk {
every { schema } returns testSchema
every { legacyRawTablesOnly } returns isLegacyRawTablesOnly
}
val snowflakeColumnUtils =
mockk<SnowflakeColumnUtils>(relaxUnitFun = true) {
every { getFormattedDefaultColumnNames(any()) } returns defaultColumns
}
val columnManager: SnowflakeColumnManager = SnowflakeColumnManager(snowflakeConfiguration)
val checker =
SnowflakeChecker(
snowflakeAirbyteClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
snowflakeColumnUtils = snowflakeColumnUtils,
columnManager = columnManager,
)
checker.check()
coVerify(exactly = 1) {
snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
if (isLegacyRawTablesOnly) {
snowflakeAirbyteClient.createNamespace(testSchema)
} else {
snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
}
}
coVerify(exactly = 1) { snowflakeAirbyteClient.createTable(any(), any(), any(), any()) }
coVerify(exactly = 1) { snowflakeAirbyteClient.dropTable(any()) }
@@ -72,48 +54,32 @@ internal class SnowflakeCheckerTest {
@ParameterizedTest
@ValueSource(booleans = [true, false])
fun testUnsuccessfulCheck(isLegacyRawTablesOnly: Boolean) {
val defaultColumnsMap =
if (isLegacyRawTablesOnly) {
linkedMapOf<String, String>().also { map ->
(DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
map[it.columnName] = it.columnType
}
}
} else {
linkedMapOf<String, String>().also { map ->
(DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
map[it.columnName.toSnowflakeCompatibleName()] = it.columnType
}
}
}
val defaultColumns = defaultColumnsMap.keys.toMutableList()
val snowflakeAirbyteClient: SnowflakeAirbyteClient =
mockk(relaxed = true) {
coEvery { countTable(any()) } returns 0L
coEvery { describeTable(any()) } returns defaultColumnsMap
}
mockk(relaxed = true) { coEvery { countTable(any()) } returns 0L }
val testSchema = "test-schema"
val snowflakeConfiguration: SnowflakeConfiguration = mockk {
every { schema } returns testSchema
every { legacyRawTablesOnly } returns isLegacyRawTablesOnly
}
val snowflakeColumnUtils =
mockk<SnowflakeColumnUtils>(relaxUnitFun = true) {
every { getFormattedDefaultColumnNames(any()) } returns defaultColumns
}
val columnManager: SnowflakeColumnManager = SnowflakeColumnManager(snowflakeConfiguration)
val checker =
SnowflakeChecker(
snowflakeAirbyteClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
snowflakeColumnUtils = snowflakeColumnUtils,
columnManager = columnManager,
)
assertThrows<IllegalArgumentException> { checker.check() }
coVerify(exactly = 1) {
snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
if (isLegacyRawTablesOnly) {
snowflakeAirbyteClient.createNamespace(testSchema)
} else {
snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
}
}
coVerify(exactly = 1) { snowflakeAirbyteClient.createTable(any(), any(), any(), any()) }
coVerify(exactly = 1) { snowflakeAirbyteClient.dropTable(any()) }
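Illustrative note (not part of the diff above): the verification above captures the new behavior that legacy raw-tables mode passes the configured schema through unchanged, while the default path applies Snowflake-compatible naming. The expected values follow the cases in the name-generator test removed later in this diff; the snippet itself is only an illustration.
// "test-name" maps to "TEST-NAME" in the removed SnowflakeNameGeneratorsTest, so:
val expectedNamespace =
    if (isLegacyRawTablesOnly) "test-schema" // passed through as-is in legacy raw mode
    else "test-schema".toSnowflakeCompatibleName() // "TEST-SCHEMA"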

View File

@@ -6,24 +6,16 @@ package io.airbyte.integrations.destination.snowflake.client
import io.airbyte.cdk.ConfigErrorException
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.command.NamespaceMapper
import io.airbyte.cdk.load.command.Overwrite
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.config.NamespaceDefinitionType
import io.airbyte.cdk.load.data.AirbyteType
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.TableName
import io.airbyte.integrations.destination.snowflake.db.ColumnDefinition
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.COUNT_TOTAL_ALIAS
import io.airbyte.integrations.destination.snowflake.sql.ColumnAndType
import io.airbyte.integrations.destination.snowflake.sql.DEFAULT_COLUMNS
import io.airbyte.integrations.destination.snowflake.sql.QUOTE
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
import io.mockk.Runs
import io.mockk.every
@@ -49,31 +41,18 @@ internal class SnowflakeAirbyteClientTest {
private lateinit var client: SnowflakeAirbyteClient
private lateinit var dataSource: DataSource
private lateinit var sqlGenerator: SnowflakeDirectLoadSqlGenerator
private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
private lateinit var snowflakeConfiguration: SnowflakeConfiguration
private lateinit var columnManager: SnowflakeColumnManager
@BeforeEach
fun setup() {
dataSource = mockk()
sqlGenerator = mockk(relaxed = true)
snowflakeColumnUtils =
mockk(relaxed = true) {
every { formatColumnName(any()) } answers
{
firstArg<String>().toSnowflakeCompatibleName()
}
every { getFormattedDefaultColumnNames(any()) } returns
DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() }
}
snowflakeConfiguration =
mockk(relaxed = true) { every { database } returns "test_database" }
columnManager = mockk(relaxed = true)
client =
SnowflakeAirbyteClient(
dataSource,
sqlGenerator,
snowflakeColumnUtils,
snowflakeConfiguration
)
SnowflakeAirbyteClient(dataSource, sqlGenerator, snowflakeConfiguration, columnManager)
}
@Test
@@ -231,7 +210,7 @@ internal class SnowflakeAirbyteClientTest {
@Test
fun testCreateTable() {
val columnNameMapping = mockk<ColumnNameMapping>(relaxed = true)
val stream = mockk<DestinationStream>()
val stream = mockk<DestinationStream>(relaxed = true)
val tableName = TableName(namespace = "namespace", name = "name")
val resultSet = mockk<ResultSet>(relaxed = true)
val statement =
@@ -254,9 +233,7 @@ internal class SnowflakeAirbyteClientTest {
columnNameMapping = columnNameMapping,
replace = true,
)
verify(exactly = 1) {
sqlGenerator.createTable(stream, tableName, columnNameMapping, true)
}
verify(exactly = 1) { sqlGenerator.createTable(tableName, any(), true) }
verify(exactly = 1) { sqlGenerator.createSnowflakeStage(tableName) }
verify(exactly = 2) { mockConnection.close() }
}
@@ -288,7 +265,7 @@ internal class SnowflakeAirbyteClientTest {
targetTableName = destinationTableName,
)
verify(exactly = 1) {
sqlGenerator.copyTable(columnNameMapping, sourceTableName, destinationTableName)
sqlGenerator.copyTable(any<Set<String>>(), sourceTableName, destinationTableName)
}
verify(exactly = 1) { mockConnection.close() }
}
@@ -299,7 +276,7 @@ internal class SnowflakeAirbyteClientTest {
val columnNameMapping = mockk<ColumnNameMapping>(relaxed = true)
val sourceTableName = TableName(namespace = "namespace", name = "source")
val destinationTableName = TableName(namespace = "namespace", name = "destination")
val stream = mockk<DestinationStream>()
val stream = mockk<DestinationStream>(relaxed = true)
val resultSet = mockk<ResultSet>(relaxed = true)
val statement =
mockk<Statement> {
@@ -322,12 +299,7 @@ internal class SnowflakeAirbyteClientTest {
targetTableName = destinationTableName,
)
verify(exactly = 1) {
sqlGenerator.upsertTable(
stream,
columnNameMapping,
sourceTableName,
destinationTableName
)
sqlGenerator.upsertTable(any(), sourceTableName, destinationTableName)
}
verify(exactly = 1) { mockConnection.close() }
}
@@ -379,7 +351,7 @@ internal class SnowflakeAirbyteClientTest {
}
every { dataSource.connection } returns mockConnection
every { snowflakeColumnUtils.getGenerationIdColumnName() } returns generationIdColumnName
every { columnManager.getGenerationIdColumnName() } returns generationIdColumnName
every { sqlGenerator.getGenerationId(tableName) } returns
"SELECT $generationIdColumnName FROM ${tableName.toPrettyString(QUOTE)}"
@@ -501,8 +473,8 @@ internal class SnowflakeAirbyteClientTest {
every { dataSource.connection } returns mockConnection
runBlocking {
client.copyFromStage(tableName, "test.csv.gz")
verify(exactly = 1) { sqlGenerator.copyFromStage(tableName, "test.csv.gz") }
client.copyFromStage(tableName, "test.csv.gz", listOf())
verify(exactly = 1) { sqlGenerator.copyFromStage(tableName, "test.csv.gz", listOf()) }
verify(exactly = 1) { mockConnection.close() }
}
}
@@ -556,7 +528,7 @@ internal class SnowflakeAirbyteClientTest {
"COL1" andThen
COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName() andThen
"COL2"
every { resultSet.getString("type") } returns "VARCHAR(255)" andThen "NUMBER(38,0)"
every { resultSet.getString("type") } returns "VARCHAR(255)" andThen "NUMBER"
every { resultSet.getString("null?") } returns "Y" andThen "N" andThen "N"
val statement =
@@ -571,6 +543,10 @@ internal class SnowflakeAirbyteClientTest {
every { dataSource.connection } returns connection
// Mock the columnManager to return the correct set of meta columns
every { columnManager.getMetaColumnNames() } returns
setOf(COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName())
val result = client.getColumnsFromDb(tableName)
val expectedColumns =
@@ -582,81 +558,6 @@ internal class SnowflakeAirbyteClientTest {
assertEquals(expectedColumns, result)
}
@Test
fun `getColumnsFromStream should return correct column definitions`() {
val schema = mockk<AirbyteType>()
val stream =
DestinationStream(
unmappedNamespace = "test_namespace",
unmappedName = "test_stream",
importType = Overwrite,
schema = schema,
generationId = 1,
minimumGenerationId = 1,
syncId = 1,
namespaceMapper = NamespaceMapper(NamespaceDefinitionType.DESTINATION)
)
val columnNameMapping =
ColumnNameMapping(
mapOf(
"col1" to "COL1_MAPPED",
"col2" to "COL2_MAPPED",
)
)
val col1FieldType = mockk<FieldType>()
every { col1FieldType.type } returns mockk()
val col2FieldType = mockk<FieldType>()
every { col2FieldType.type } returns mockk()
every { schema.asColumns() } returns
linkedMapOf("col1" to col1FieldType, "col2" to col2FieldType)
every { snowflakeColumnUtils.toDialectType(col1FieldType.type) } returns "VARCHAR(255)"
every { snowflakeColumnUtils.toDialectType(col2FieldType.type) } returns "NUMBER(38,0)"
every { snowflakeColumnUtils.columnsAndTypes(any(), any()) } returns
listOf(ColumnAndType("COL1_MAPPED", "VARCHAR"), ColumnAndType("COL2_MAPPED", "NUMBER"))
every { snowflakeColumnUtils.formatColumnName(any(), false) } answers
{
firstArg<String>().toSnowflakeCompatibleName()
}
val result = client.getColumnsFromStream(stream, columnNameMapping)
val expectedColumns =
mapOf(
"COL1_MAPPED" to ColumnType("VARCHAR", true),
"COL2_MAPPED" to ColumnType("NUMBER", true),
)
assertEquals(expectedColumns, result)
}
@Test
fun `generateSchemaChanges should correctly identify changes`() {
val columnsInDb =
setOf(
ColumnDefinition("COL1", "VARCHAR"),
ColumnDefinition("COL2", "NUMBER"),
ColumnDefinition("COL3", "BOOLEAN")
)
val columnsInStream =
setOf(
ColumnDefinition("COL1", "VARCHAR"), // Unchanged
ColumnDefinition("COL3", "TEXT"), // Modified
ColumnDefinition("COL4", "DATE") // Added
)
val (added, deleted, modified) = client.generateSchemaChanges(columnsInDb, columnsInStream)
assertEquals(1, added.size)
assertEquals("COL4", added.first().name)
assertEquals(1, deleted.size)
assertEquals("COL2", deleted.first().name)
assertEquals(1, modified.size)
assertEquals("COL3", modified.first().name)
}
@Test
fun testCreateNamespaceWithNetworkFailure() {
val namespace = "test_namespace"

View File

@@ -4,14 +4,19 @@
package io.airbyte.integrations.destination.snowflake.dataflow
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.command.DestinationStream.Descriptor
import io.airbyte.cdk.load.dataflow.aggregate.StoreKey
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.write.StreamStateStore
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRawRecordFormatter
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertEquals
@@ -22,27 +27,33 @@ internal class SnowflakeAggregateFactoryTest {
@Test
fun testCreatingAggregateWithRawBuffer() {
val descriptor = Descriptor(namespace = "namespace", name = "name")
val directLoadTableExecutionConfig =
DirectLoadTableExecutionConfig(
tableName =
TableName(
namespace = descriptor.namespace!!,
name = descriptor.name,
)
val tableName =
TableName(
namespace = descriptor.namespace!!,
name = descriptor.name,
)
val directLoadTableExecutionConfig = DirectLoadTableExecutionConfig(tableName = tableName)
val key = StoreKey(namespace = descriptor.namespace!!, name = descriptor.name)
val streamStore = StreamStateStore<DirectLoadTableExecutionConfig>()
streamStore.put(descriptor, directLoadTableExecutionConfig)
streamStore.put(key, directLoadTableExecutionConfig)
val stream = mockk<DestinationStream>(relaxed = true)
val catalog = mockk<DestinationCatalog> { every { getStream(key) } returns stream }
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val snowflakeConfiguration =
mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns true }
val snowflakeColumnUtils = mockk<SnowflakeColumnUtils>(relaxed = true)
val columnManager = SnowflakeColumnManager(snowflakeConfiguration)
val snowflakeRecordFormatter: SnowflakeRecordFormatter = SnowflakeRawRecordFormatter()
val factory =
SnowflakeAggregateFactory(
snowflakeClient = snowflakeClient,
streamStateStore = streamStore,
snowflakeConfiguration = snowflakeConfiguration,
snowflakeColumnUtils = snowflakeColumnUtils,
catalog = catalog,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
)
val aggregate = factory.create(key)
assertNotNull(aggregate)
@@ -52,26 +63,33 @@ internal class SnowflakeAggregateFactoryTest {
@Test
fun testCreatingAggregateWithStagingBuffer() {
val descriptor = Descriptor(namespace = "namespace", name = "name")
val directLoadTableExecutionConfig =
DirectLoadTableExecutionConfig(
tableName =
TableName(
namespace = descriptor.namespace!!,
name = descriptor.name,
)
val tableName =
TableName(
namespace = descriptor.namespace!!,
name = descriptor.name,
)
val directLoadTableExecutionConfig = DirectLoadTableExecutionConfig(tableName = tableName)
val key = StoreKey(namespace = descriptor.namespace!!, name = descriptor.name)
val streamStore = StreamStateStore<DirectLoadTableExecutionConfig>()
streamStore.put(descriptor, directLoadTableExecutionConfig)
streamStore.put(key, directLoadTableExecutionConfig)
val stream = mockk<DestinationStream>(relaxed = true)
val catalog = mockk<DestinationCatalog> { every { getStream(key) } returns stream }
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val snowflakeConfiguration = mockk<SnowflakeConfiguration>(relaxed = true)
val snowflakeColumnUtils = mockk<SnowflakeColumnUtils>(relaxed = true)
val snowflakeConfiguration =
mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns false }
val columnManager = SnowflakeColumnManager(snowflakeConfiguration)
val snowflakeRecordFormatter: SnowflakeRecordFormatter = SnowflakeSchemaRecordFormatter()
val factory =
SnowflakeAggregateFactory(
snowflakeClient = snowflakeClient,
streamStateStore = streamStore,
snowflakeConfiguration = snowflakeConfiguration,
snowflakeColumnUtils = snowflakeColumnUtils,
catalog = catalog,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
)
val aggregate = factory.create(key)
assertNotNull(aggregate)

View File

@@ -1,23 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.db
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
internal class SnowflakeColumnNameGeneratorTest {
@Test
fun testGetColumnName() {
val column = "test-column"
val generator =
SnowflakeColumnNameGenerator(mockk { every { legacyRawTablesOnly } returns false })
val columnName = generator.getColumnName(column)
assertEquals(column.toSnowflakeCompatibleName(), columnName.displayName)
assertEquals(column.toSnowflakeCompatibleName(), columnName.canonicalName)
}
}

View File

@@ -1,72 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.db
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
internal class SnowflakeFinalTableNameGeneratorTest {
@Test
fun testGetTableNameWithInternalNamespace() {
val configuration =
mockk<SnowflakeConfiguration> {
every { internalTableSchema } returns "test-internal-namespace"
every { legacyRawTablesOnly } returns true
}
val generator = SnowflakeFinalTableNameGenerator(config = configuration)
val streamName = "test-stream-name"
val streamNamespace = "test-stream-namespace"
val streamDescriptor =
mockk<DestinationStream.Descriptor> {
every { namespace } returns streamNamespace
every { name } returns streamName
}
val tableName = generator.getTableName(streamDescriptor)
assertEquals("test-stream-namespace_raw__stream_test-stream-name", tableName.name)
assertEquals("test-internal-namespace", tableName.namespace)
}
@Test
fun testGetTableNameWithNamespace() {
val configuration =
mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns false }
val generator = SnowflakeFinalTableNameGenerator(config = configuration)
val streamName = "test-stream-name"
val streamNamespace = "test-stream-namespace"
val streamDescriptor =
mockk<DestinationStream.Descriptor> {
every { namespace } returns streamNamespace
every { name } returns streamName
}
val tableName = generator.getTableName(streamDescriptor)
assertEquals("TEST-STREAM-NAME", tableName.name)
assertEquals("TEST-STREAM-NAMESPACE", tableName.namespace)
}
@Test
fun testGetTableNameWithDefaultNamespace() {
val defaultNamespace = "test-default-namespace"
val configuration =
mockk<SnowflakeConfiguration> {
every { schema } returns defaultNamespace
every { legacyRawTablesOnly } returns false
}
val generator = SnowflakeFinalTableNameGenerator(config = configuration)
val streamName = "test-stream-name"
val streamDescriptor =
mockk<DestinationStream.Descriptor> {
every { namespace } returns null
every { name } returns streamName
}
val tableName = generator.getTableName(streamDescriptor)
assertEquals("TEST-STREAM-NAME", tableName.name)
assertEquals("TEST-DEFAULT-NAMESPACE", tableName.namespace)
}
}

View File

@@ -1,27 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.db
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.CsvSource
internal class SnowflakeNameGeneratorsTest {
@ParameterizedTest
@CsvSource(
value =
[
"test-name,TEST-NAME",
"1-test-name,1-TEST-NAME",
"test-name!!!,TEST-NAME!!!",
"test\${name,TEST__NAME",
"test\"name,TEST\"\"NAME",
]
)
fun testToSnowflakeCompatibleName(name: String, expected: String) {
assertEquals(expected, name.toSnowflakeCompatibleName())
}
}

View File

@@ -1,358 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.sql
import com.fasterxml.jackson.databind.JsonNode
import io.airbyte.cdk.load.data.ArrayType
import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
import io.airbyte.cdk.load.data.BooleanType
import io.airbyte.cdk.load.data.DateType
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerType
import io.airbyte.cdk.load.data.NumberType
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.TimeTypeWithTimezone
import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
import io.airbyte.cdk.load.data.UnionType
import io.airbyte.cdk.load.data.UnknownType
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA
import io.airbyte.cdk.load.orchestration.db.ColumnNameGenerator
import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.integrations.destination.snowflake.db.SnowflakeColumnNameGenerator
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.mockk.every
import io.mockk.mockk
import kotlin.collections.LinkedHashMap
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.CsvSource
internal class SnowflakeColumnUtilsTest {
private lateinit var snowflakeConfiguration: SnowflakeConfiguration
private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
private lateinit var snowflakeColumnNameGenerator: SnowflakeColumnNameGenerator
@BeforeEach
fun setup() {
snowflakeConfiguration = mockk(relaxed = true)
snowflakeColumnNameGenerator =
mockk(relaxed = true) {
every { getColumnName(any()) } answers
{
val displayName =
if (snowflakeConfiguration.legacyRawTablesOnly) firstArg<String>()
else firstArg<String>().toSnowflakeCompatibleName()
val canonicalName =
if (snowflakeConfiguration.legacyRawTablesOnly) firstArg<String>()
else firstArg<String>().toSnowflakeCompatibleName()
ColumnNameGenerator.ColumnName(
displayName = displayName,
canonicalName = canonicalName,
)
}
}
snowflakeColumnUtils =
SnowflakeColumnUtils(snowflakeConfiguration, snowflakeColumnNameGenerator)
}
@Test
fun testDefaultColumns() {
val expectedDefaultColumns = DEFAULT_COLUMNS
assertEquals(expectedDefaultColumns, snowflakeColumnUtils.defaultColumns())
}
@Test
fun testDefaultRawColumns() {
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val expectedDefaultColumns = DEFAULT_COLUMNS + RAW_COLUMNS
assertEquals(expectedDefaultColumns, snowflakeColumnUtils.defaultColumns())
}
@Test
fun testGetFormattedDefaultColumnNames() {
val expectedDefaultColumnNames =
DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() }
val defaultColumnNames = snowflakeColumnUtils.getFormattedDefaultColumnNames()
assertEquals(expectedDefaultColumnNames, defaultColumnNames)
}
@Test
fun testGetFormattedDefaultColumnNamesQuoted() {
val expectedDefaultColumnNames =
DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName().quote() }
val defaultColumnNames = snowflakeColumnUtils.getFormattedDefaultColumnNames(true)
assertEquals(expectedDefaultColumnNames, defaultColumnNames)
}
@Test
fun testGetColumnName() {
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val columnNames = snowflakeColumnUtils.getColumnNames(columnNameMapping)
val expectedColumnNames =
(DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() } + listOf("actual"))
.joinToString(",") { it.quote() }
assertEquals(expectedColumnNames, columnNames)
}
@Test
fun testGetRawColumnName() {
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val columnNames = snowflakeColumnUtils.getColumnNames(columnNameMapping)
val expectedColumnNames =
(DEFAULT_COLUMNS.map { it.columnName } + RAW_COLUMNS.map { it.columnName })
.joinToString(",") { it.quote() }
assertEquals(expectedColumnNames, columnNames)
}
@Test
fun testGetRawFormattedColumnNames() {
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val schemaColumns =
mapOf(
"column_one" to FieldType(StringType, true),
"column_two" to FieldType(IntegerType, true),
"original" to FieldType(StringType, true),
CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true)
)
val expectedColumnNames =
DEFAULT_COLUMNS.map { it.columnName.quote() } +
RAW_COLUMNS.map { it.columnName.quote() }
val columnNames =
snowflakeColumnUtils.getFormattedColumnNames(
columns = schemaColumns,
columnNameMapping = columnNameMapping
)
assertEquals(expectedColumnNames.size, columnNames.size)
assertEquals(expectedColumnNames.sorted(), columnNames.sorted())
}
@Test
fun testGetFormattedColumnNames() {
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val schemaColumns =
mapOf(
"column_one" to FieldType(StringType, true),
"column_two" to FieldType(IntegerType, true),
"original" to FieldType(StringType, true),
CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true)
)
val expectedColumnNames =
listOf(
"actual",
"column_one",
"column_two",
CDC_DELETED_AT_COLUMN,
)
.map { it.quote() } +
DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName().quote() }
val columnNames =
snowflakeColumnUtils.getFormattedColumnNames(
columns = schemaColumns,
columnNameMapping = columnNameMapping
)
assertEquals(expectedColumnNames.size, columnNames.size)
assertEquals(expectedColumnNames.sorted(), columnNames.sorted())
}
@Test
fun testGetFormattedColumnNamesNoQuotes() {
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val schemaColumns =
mapOf(
"column_one" to FieldType(StringType, true),
"column_two" to FieldType(IntegerType, true),
"original" to FieldType(StringType, true),
CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true)
)
val expectedColumnNames =
listOf(
"actual",
"column_one",
"column_two",
CDC_DELETED_AT_COLUMN,
) + DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() }
val columnNames =
snowflakeColumnUtils.getFormattedColumnNames(
columns = schemaColumns,
columnNameMapping = columnNameMapping,
quote = false
)
assertEquals(expectedColumnNames.size, columnNames.size)
assertEquals(expectedColumnNames.sorted(), columnNames.sorted())
}
@Test
fun testGeneratingRawTableColumnsAndTypesNoColumnMapping() {
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val columns =
snowflakeColumnUtils.columnsAndTypes(
columns = emptyMap(),
columnNameMapping = ColumnNameMapping(emptyMap())
)
assertEquals(DEFAULT_COLUMNS.size + RAW_COLUMNS.size, columns.size)
assertEquals(
"${SnowflakeDataType.VARIANT.typeName} $NOT_NULL",
columns.find { it.columnName == RAW_DATA_COLUMN.columnName }?.columnType
)
}
@Test
fun testGeneratingColumnsAndTypesNoColumnMapping() {
val columnName = "test-column"
val fieldType = FieldType(StringType, false)
val declaredColumns = mapOf(columnName to fieldType)
val columns =
snowflakeColumnUtils.columnsAndTypes(
columns = declaredColumns,
columnNameMapping = ColumnNameMapping(emptyMap())
)
assertEquals(DEFAULT_COLUMNS.size + 1, columns.size)
assertEquals(
"${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL",
columns.find { it.columnName == columnName }?.columnType
)
}
@Test
fun testGeneratingColumnsAndTypesWithColumnMapping() {
val columnName = "test-column"
val mappedColumnName = "mapped-column-name"
val fieldType = FieldType(StringType, false)
val declaredColumns = mapOf(columnName to fieldType)
val columnNameMapping = ColumnNameMapping(mapOf(columnName to mappedColumnName))
val columns =
snowflakeColumnUtils.columnsAndTypes(
columns = declaredColumns,
columnNameMapping = columnNameMapping
)
assertEquals(DEFAULT_COLUMNS.size + 1, columns.size)
assertEquals(
"${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL",
columns.find { it.columnName == mappedColumnName }?.columnType
)
}
@Test
fun testToDialectType() {
assertEquals(
SnowflakeDataType.BOOLEAN.typeName,
snowflakeColumnUtils.toDialectType(BooleanType)
)
assertEquals(SnowflakeDataType.DATE.typeName, snowflakeColumnUtils.toDialectType(DateType))
assertEquals(
SnowflakeDataType.NUMBER.typeName,
snowflakeColumnUtils.toDialectType(IntegerType)
)
assertEquals(
SnowflakeDataType.FLOAT.typeName,
snowflakeColumnUtils.toDialectType(NumberType)
)
assertEquals(
SnowflakeDataType.VARCHAR.typeName,
snowflakeColumnUtils.toDialectType(StringType)
)
assertEquals(
SnowflakeDataType.VARCHAR.typeName,
snowflakeColumnUtils.toDialectType(TimeTypeWithTimezone)
)
assertEquals(
SnowflakeDataType.TIME.typeName,
snowflakeColumnUtils.toDialectType(TimeTypeWithoutTimezone)
)
assertEquals(
SnowflakeDataType.TIMESTAMP_TZ.typeName,
snowflakeColumnUtils.toDialectType(TimestampTypeWithTimezone)
)
assertEquals(
SnowflakeDataType.TIMESTAMP_NTZ.typeName,
snowflakeColumnUtils.toDialectType(TimestampTypeWithoutTimezone)
)
assertEquals(
SnowflakeDataType.ARRAY.typeName,
snowflakeColumnUtils.toDialectType(ArrayType(items = FieldType(StringType, false)))
)
assertEquals(
SnowflakeDataType.ARRAY.typeName,
snowflakeColumnUtils.toDialectType(ArrayTypeWithoutSchema)
)
assertEquals(
SnowflakeDataType.OBJECT.typeName,
snowflakeColumnUtils.toDialectType(
ObjectType(
properties = LinkedHashMap(),
additionalProperties = false,
)
)
)
assertEquals(
SnowflakeDataType.OBJECT.typeName,
snowflakeColumnUtils.toDialectType(ObjectTypeWithEmptySchema)
)
assertEquals(
SnowflakeDataType.OBJECT.typeName,
snowflakeColumnUtils.toDialectType(ObjectTypeWithoutSchema)
)
assertEquals(
SnowflakeDataType.VARIANT.typeName,
snowflakeColumnUtils.toDialectType(
UnionType(
options = setOf(StringType),
isLegacyUnion = true,
)
)
)
assertEquals(
SnowflakeDataType.VARIANT.typeName,
snowflakeColumnUtils.toDialectType(
UnionType(
options = emptySet(),
isLegacyUnion = false,
)
)
)
assertEquals(
SnowflakeDataType.VARIANT.typeName,
snowflakeColumnUtils.toDialectType(UnknownType(schema = mockk<JsonNode>()))
)
}
@ParameterizedTest
@CsvSource(
value =
[
"$COLUMN_NAME_DATA, true, \"$COLUMN_NAME_DATA\"",
"some-other_Column, true, \"SOME-OTHER_COLUMN\"",
"$COLUMN_NAME_DATA, false, $COLUMN_NAME_DATA",
"some-other_Column, false, SOME-OTHER_COLUMN",
"$COLUMN_NAME_DATA, true, \"$COLUMN_NAME_DATA\"",
"some-other_Column, true, \"SOME-OTHER_COLUMN\"",
]
)
fun testFormatColumnName(columnName: String, quote: Boolean, expectedFormattedName: String) {
assertEquals(
expectedFormattedName,
snowflakeColumnUtils.formatColumnName(columnName, quote)
)
}
}

View File

@@ -1,97 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.sql
import io.airbyte.cdk.load.table.TableName
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
internal class SnowflakeSqlNameUtilsTest {
private lateinit var snowflakeConfiguration: SnowflakeConfiguration
private lateinit var snowflakeSqlNameUtils: SnowflakeSqlNameUtils
@BeforeEach
fun setUp() {
snowflakeConfiguration = mockk(relaxed = true)
snowflakeSqlNameUtils =
SnowflakeSqlNameUtils(snowflakeConfiguration = snowflakeConfiguration)
}
@Test
fun testFullyQualifiedName() {
val databaseName = "test-database"
val namespace = "test-namespace"
val name = "test=name"
val tableName = TableName(namespace = namespace, name = name)
every { snowflakeConfiguration.database } returns databaseName
val expectedName =
snowflakeSqlNameUtils.combineParts(
listOf(
databaseName.toSnowflakeCompatibleName(),
tableName.namespace,
tableName.name
)
)
val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedName(tableName)
assertEquals(expectedName, fullyQualifiedName)
}
@Test
fun testFullyQualifiedNamespace() {
val databaseName = "test-database"
val namespace = "test-namespace"
every { snowflakeConfiguration.database } returns databaseName
val fullyQualifiedNamespace = snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)
assertEquals("\"TEST-DATABASE\".\"test-namespace\"", fullyQualifiedNamespace)
}
@Test
fun testFullyQualifiedStageName() {
val databaseName = "test-database"
val namespace = "test-namespace"
val name = "test=name"
val tableName = TableName(namespace = namespace, name = name)
every { snowflakeConfiguration.database } returns databaseName
val expectedName =
snowflakeSqlNameUtils.combineParts(
listOf(
databaseName.toSnowflakeCompatibleName(),
namespace,
"$STAGE_NAME_PREFIX$name"
)
)
val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName)
assertEquals(expectedName, fullyQualifiedName)
}
@Test
fun testFullyQualifiedStageNameWithEscape() {
val databaseName = "test-database"
val namespace = "test-namespace"
val name = "test=\"\"\'name"
val tableName = TableName(namespace = namespace, name = name)
every { snowflakeConfiguration.database } returns databaseName
val expectedName =
snowflakeSqlNameUtils.combineParts(
listOf(
databaseName.toSnowflakeCompatibleName(),
namespace,
"$STAGE_NAME_PREFIX${sqlEscape(name)}"
)
)
val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true)
assertEquals(expectedName, fullyQualifiedName)
}
}

View File

@@ -6,20 +6,21 @@ package io.airbyte.integrations.destination.snowflake.write
import io.airbyte.cdk.SystemErrorException
import io.airbyte.cdk.load.command.Append
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.orchestration.db.DatabaseInitialStatusGatherer
import io.airbyte.cdk.load.orchestration.db.TableNames
import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadInitialStatus
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendStreamLoader
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendTruncateStreamLoader
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableStatus
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableNameInfo
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.cdk.load.schema.model.StreamTableSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.schema.model.TableNames
import io.airbyte.cdk.load.table.DatabaseInitialStatusGatherer
import io.airbyte.cdk.load.table.TempTableNameGenerator
import io.airbyte.cdk.load.table.directload.DirectLoadInitialStatus
import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendStreamLoader
import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendTruncateStreamLoader
import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.table.directload.DirectLoadTableStatus
import io.airbyte.cdk.load.write.StreamStateStore
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
@@ -34,55 +35,93 @@ internal class SnowflakeWriterTest {
@Test
fun testSetup() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val stream = mockk<DestinationStream>()
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream =
mockk<DestinationStream> {
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
every { importType } returns Append
}
val catalog = DestinationCatalog(listOf(stream))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns emptyMap()
coEvery { gatherInitialStatus() } returns
mapOf(
stream to
DirectLoadInitialStatus(
realTable = DirectLoadTableStatus(false),
tempTable = null
)
)
}
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer =
SnowflakeWriter(
names = catalog,
catalog = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(relaxed = true),
snowflakeConfiguration =
mockk(relaxed = true) {
every { internalTableSchema } returns "internal_schema"
},
)
runBlocking { writer.setup() }
coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) }
coVerify(exactly = 1) { stateGatherer.gatherInitialStatus(catalog) }
coVerify(exactly = 1) { stateGatherer.gatherInitialStatus() }
}
@Test
fun testCreateStreamLoaderFirstGeneration() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream =
mockk<DestinationStream> {
every { minimumGenerationId } returns 0L
every { generationId } returns 0L
every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
}
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val catalog = DestinationCatalog(listOf(stream))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns
coEvery { gatherInitialStatus() } returns
mapOf(
stream to
DirectLoadInitialStatus(
@@ -93,14 +132,18 @@ internal class SnowflakeWriterTest {
}
val tempTableNameGenerator =
mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } }
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer =
SnowflakeWriter(
names = catalog,
catalog = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient,
tempTableNameGenerator = tempTableNameGenerator,
snowflakeConfiguration = mockk(relaxed = true),
snowflakeConfiguration =
mockk(relaxed = true) {
every { internalTableSchema } returns "internal_schema"
},
)
runBlocking {
@@ -113,23 +156,35 @@ internal class SnowflakeWriterTest {
@Test
fun testCreateStreamLoaderNotFirstGeneration() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream =
mockk<DestinationStream> {
every { minimumGenerationId } returns 1L
every { generationId } returns 1L
every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
}
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val catalog = DestinationCatalog(listOf(stream))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns
coEvery { gatherInitialStatus() } returns
mapOf(
stream to
DirectLoadInitialStatus(
@@ -140,14 +195,18 @@ internal class SnowflakeWriterTest {
}
val tempTableNameGenerator =
mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } }
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer =
SnowflakeWriter(
names = catalog,
catalog = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient,
tempTableNameGenerator = tempTableNameGenerator,
snowflakeConfiguration = mockk(relaxed = true),
snowflakeConfiguration =
mockk(relaxed = true) {
every { internalTableSchema } returns "internal_schema"
},
)
runBlocking {
@@ -160,22 +219,35 @@ internal class SnowflakeWriterTest {
@Test
fun testCreateStreamLoaderHybrid() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream =
mockk<DestinationStream> {
every { minimumGenerationId } returns 1L
every { generationId } returns 2L
every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
}
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val catalog = DestinationCatalog(listOf(stream))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns
coEvery { gatherInitialStatus() } returns
mapOf(
stream to
DirectLoadInitialStatus(
@@ -186,14 +258,18 @@ internal class SnowflakeWriterTest {
}
val tempTableNameGenerator =
mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } }
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer =
SnowflakeWriter(
names = catalog,
catalog = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient,
tempTableNameGenerator = tempTableNameGenerator,
snowflakeConfiguration = mockk(relaxed = true),
snowflakeConfiguration =
mockk(relaxed = true) {
every { internalTableSchema } returns "internal_schema"
},
)
runBlocking {
@@ -203,169 +279,126 @@ internal class SnowflakeWriterTest {
}
@Test
fun testSetupWithNamespaceCreationFailure() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val stream = mockk<DestinationStream>()
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val snowflakeClient = mockk<SnowflakeAirbyteClient>()
val stateGatherer = mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>>()
val writer =
SnowflakeWriter(
names = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(),
)
// Simulate network failure during namespace creation
coEvery {
snowflakeClient.createNamespace(tableName.namespace.toSnowflakeCompatibleName())
} throws RuntimeException("Network connection failed")
assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } }
}
@Test
fun testSetupWithInitialStatusGatheringFailure() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val stream = mockk<DestinationStream>()
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer = mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>>()
val writer =
SnowflakeWriter(
names = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(),
)
// Simulate failure while gathering initial status
coEvery { stateGatherer.gatherInitialStatus(catalog) } throws
RuntimeException("Failed to query table status")
assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } }
// Verify namespace creation was still attempted
coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) }
}
@Test
fun testCreateStreamLoaderWithMissingInitialStatus() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
fun testCreateStreamLoaderNamespaceLegacy() {
val namespace = "test-namespace"
val name = "test-name"
val tableName = TableName(namespace = namespace, name = name)
val tempTableName = TableName(namespace = namespace, name = "${name}-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream =
mockk<DestinationStream> {
every { minimumGenerationId } returns 0L
every { generationId } returns 0L
every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
}
val missingStream =
mockk<DestinationStream> {
every { minimumGenerationId } returns 0L
every { generationId } returns 0L
}
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val catalog = DestinationCatalog(listOf(stream))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns
coEvery { gatherInitialStatus() } returns
mapOf(
stream to
DirectLoadInitialStatus(
realTable = DirectLoadTableStatus(false),
tempTable = null,
tempTable = null
)
)
}
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer =
SnowflakeWriter(
names = catalog,
catalog = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(relaxed = true),
snowflakeConfiguration =
mockk(relaxed = true) {
every { legacyRawTablesOnly } returns true
every { internalTableSchema } returns "internal_schema"
},
)
runBlocking {
writer.setup()
// Try to create loader for a stream that wasn't in initial status
assertThrows(NullPointerException::class.java) {
writer.createStreamLoader(missingStream)
runBlocking { writer.setup() }
coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) }
}
@Test
fun testCreateStreamLoaderNamespaceNonLegacy() {
val namespace = "test-namespace"
val name = "test-name"
val tableName = TableName(namespace = namespace, name = name)
val tempTableName = TableName(namespace = namespace, name = "${name}-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream =
mockk<DestinationStream> {
every { minimumGenerationId } returns 0L
every { generationId } returns 0L
every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
}
}
}
@Test
fun testCreateStreamLoaderWithNullFinalTableName() {
// TableNames constructor throws IllegalStateException when both names are null
assertThrows(IllegalStateException::class.java) {
TableNames(rawTableName = null, finalTableName = null)
}
}
@Test
fun testSetupWithMultipleNamespaceFailuresPartial() {
val tableName1 = TableName(namespace = "namespace1", name = "table1")
val tableName2 = TableName(namespace = "namespace2", name = "table2")
val tableNames1 = TableNames(rawTableName = null, finalTableName = tableName1)
val tableNames2 = TableNames(rawTableName = null, finalTableName = tableName2)
val stream1 = mockk<DestinationStream>()
val stream2 = mockk<DestinationStream>()
val tableInfo1 =
TableNameInfo(
tableNames = tableNames1,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val tableInfo2 =
TableNameInfo(
tableNames = tableNames2,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream1 to tableInfo1, stream2 to tableInfo2))
val snowflakeClient = mockk<SnowflakeAirbyteClient>()
val stateGatherer = mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>>()
val catalog = DestinationCatalog(listOf(stream))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus() } returns
mapOf(
stream to
DirectLoadInitialStatus(
realTable = DirectLoadTableStatus(false),
tempTable = null
)
)
}
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer =
SnowflakeWriter(
names = catalog,
catalog = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(),
snowflakeConfiguration =
mockk(relaxed = true) {
every { legacyRawTablesOnly } returns false
every { internalTableSchema } returns "internal_schema"
},
)
// First namespace succeeds, second fails (namespaces are uppercased by
// toSnowflakeCompatibleName)
coEvery { snowflakeClient.createNamespace("namespace1") } returns Unit
coEvery { snowflakeClient.createNamespace("namespace2") } throws
RuntimeException("Connection timeout")
runBlocking { writer.setup() }
assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } }
// Verify both namespace creations were attempted
coVerify(exactly = 1) { snowflakeClient.createNamespace("namespace1") }
coVerify(exactly = 1) { snowflakeClient.createNamespace("namespace2") }
coVerify(exactly = 1) { snowflakeClient.createNamespace(namespace) }
}
}

View File

@@ -4,218 +4,220 @@
package io.airbyte.integrations.destination.snowflake.write.load
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerValue
import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.mockk.coVerify
import io.mockk.every
import io.mockk.mockk
import java.io.BufferedReader
import java.io.File
import java.io.InputStreamReader
import java.util.zip.GZIPInputStream
import kotlin.io.path.exists
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNotNull
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
internal class SnowflakeInsertBufferTest {
private lateinit var snowflakeConfiguration: SnowflakeConfiguration
private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
private lateinit var columnManager: SnowflakeColumnManager
private lateinit var columnSchema: ColumnSchema
private lateinit var snowflakeRecordFormatter: SnowflakeRecordFormatter
@BeforeEach
fun setUp() {
snowflakeConfiguration = mockk(relaxed = true)
snowflakeColumnUtils = mockk(relaxed = true)
snowflakeRecordFormatter = SnowflakeSchemaRecordFormatter()
columnManager =
mockk(relaxed = true) {
every { getMetaColumns() } returns
linkedMapOf(
"_AIRBYTE_RAW_ID" to ColumnType("VARCHAR", false),
"_AIRBYTE_EXTRACTED_AT" to ColumnType("TIMESTAMP_TZ", false),
"_AIRBYTE_META" to ColumnType("VARIANT", false),
"_AIRBYTE_GENERATION_ID" to ColumnType("NUMBER", true)
)
every { getTableColumnNames(any()) } returns
listOf(
"_AIRBYTE_RAW_ID",
"_AIRBYTE_EXTRACTED_AT",
"_AIRBYTE_META",
"_AIRBYTE_GENERATION_ID",
"columnName"
)
}
}
@Test
fun testAccumulate() {
val tableName = mockk<TableName>(relaxed = true)
val tableName = TableName(namespace = "test", name = "table")
val column = "columnName"
val columns = linkedMapOf(column to "NUMBER(38,0)")
columnSchema =
ColumnSchema(
inputToFinalColumnNames = mapOf(column to column.uppercase()),
finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
inputSchema = mapOf(column to FieldType(StringType, nullable = true))
)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord(column)
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = columns,
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
columnSchema = columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
)
buffer.accumulate(record)
assertEquals(true, buffer.csvFilePath?.exists())
assertEquals(0, buffer.recordCount)
runBlocking { buffer.accumulate(record) }
assertEquals(1, buffer.recordCount)
}
@Test
fun testAccumulateRaw() {
val tableName = mockk<TableName>(relaxed = true)
fun testFlushToStaging() {
val tableName = TableName(namespace = "test", name = "table")
val column = "columnName"
val columns = linkedMapOf(column to "NUMBER(38,0)")
columnSchema =
ColumnSchema(
inputToFinalColumnNames = mapOf(column to column.uppercase()),
finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
inputSchema = mapOf(column to FieldType(StringType, nullable = true))
)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord(column)
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = columns,
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
columnSchema = columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
)
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
buffer.accumulate(record)
assertEquals(true, buffer.csvFilePath?.exists())
assertEquals(1, buffer.recordCount)
}
@Test
fun testFlush() {
val tableName = mockk<TableName>(relaxed = true)
val column = "columnName"
val columns = linkedMapOf(column to "NUMBER(38,0)")
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord(column)
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = columns,
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
val expectedColumnNames =
listOf(
"_AIRBYTE_RAW_ID",
"_AIRBYTE_EXTRACTED_AT",
"_AIRBYTE_META",
"_AIRBYTE_GENERATION_ID",
"columnName"
)
runBlocking {
buffer.accumulate(record)
buffer.flush()
}
coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
coVerify(exactly = 1) {
snowflakeAirbyteClient.copyFromStage(
tableName,
match { it.endsWith("$CSV_FILE_EXTENSION$FILE_SUFFIX") }
)
coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
coVerify(exactly = 1) {
snowflakeAirbyteClient.copyFromStage(tableName, any(), expectedColumnNames)
}
}
}
@Test
fun testFlushRaw() {
val tableName = mockk<TableName>(relaxed = true)
fun testFlushToNoStaging() {
val tableName = TableName(namespace = "test", name = "table")
val column = "columnName"
val columns = linkedMapOf(column to "NUMBER(38,0)")
columnSchema =
ColumnSchema(
inputToFinalColumnNames = mapOf(column to column.uppercase()),
finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
inputSchema = mapOf(column to FieldType(StringType, nullable = true))
)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val record = createRecord(column)
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
columnSchema = columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
flushLimit = 1,
)
val expectedColumnNames =
listOf(
"_AIRBYTE_RAW_ID",
"_AIRBYTE_EXTRACTED_AT",
"_AIRBYTE_META",
"_AIRBYTE_GENERATION_ID",
"columnName"
)
runBlocking {
buffer.accumulate(record)
buffer.flush()
// In legacy raw mode, it still uses staging
coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
coVerify(exactly = 1) {
snowflakeAirbyteClient.copyFromStage(tableName, any(), expectedColumnNames)
}
}
}
@Test
fun testFileCreation() {
val tableName = TableName(namespace = "test", name = "table")
val column = "columnName"
columnSchema =
ColumnSchema(
inputToFinalColumnNames = mapOf(column to column.uppercase()),
finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
inputSchema = mapOf(column to FieldType(StringType, nullable = true))
)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord(column)
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = columns,
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
columnSchema = columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
)
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
runBlocking {
buffer.accumulate(record)
buffer.flush()
}
coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
coVerify(exactly = 1) {
snowflakeAirbyteClient.copyFromStage(
tableName,
match { it.endsWith("$CSV_FILE_EXTENSION$FILE_SUFFIX") }
)
}
}
@Test
fun testMissingFields() {
val tableName = mockk<TableName>(relaxed = true)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord("COLUMN1")
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = linkedMapOf("COLUMN1" to "NUMBER(38,0)", "COLUMN2" to "NUMBER(38,0)"),
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
)
runBlocking {
buffer.accumulate(record)
buffer.csvWriter?.flush()
            // The csvFilePath is internal; we can access it for testing
val filepath = buffer.csvFilePath
assertNotNull(filepath)
val file = filepath!!.toFile()
assert(file.exists())
// Close the writer to ensure all data is flushed
buffer.csvWriter?.close()
assertEquals(
"test-value$CSV_FIELD_SEPARATOR$CSV_LINE_DELIMITER",
readFromCsvFile(buffer.csvFilePath!!.toFile())
)
val lines = mutableListOf<String>()
GZIPInputStream(file.inputStream()).use { gzip ->
BufferedReader(InputStreamReader(gzip)).use { bufferedReader ->
bufferedReader.forEachLine { line -> lines.add(line) }
}
}
assertEquals(1, lines.size)
file.delete()
}
}
@Test
fun testMissingFieldsRaw() {
val tableName = mockk<TableName>(relaxed = true)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord("COLUMN1")
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = linkedMapOf("COLUMN1" to "NUMBER(38,0)", "COLUMN2" to "NUMBER(38,0)"),
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
)
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
runBlocking {
buffer.accumulate(record)
buffer.csvWriter?.flush()
buffer.csvWriter?.close()
assertEquals(
"test-value$CSV_FIELD_SEPARATOR$CSV_LINE_DELIMITER",
readFromCsvFile(buffer.csvFilePath!!.toFile())
)
}
}
private fun readFromCsvFile(file: File) =
GZIPInputStream(file.inputStream()).use { input ->
val reader = BufferedReader(InputStreamReader(input))
reader.readText()
}
private fun createRecord(columnName: String) =
mapOf(
columnName to AirbyteValue.from("test-value"),
Meta.COLUMN_NAME_AB_EXTRACTED_AT to IntegerValue(System.currentTimeMillis()),
Meta.COLUMN_NAME_AB_RAW_ID to StringValue("raw-id"),
Meta.COLUMN_NAME_AB_GENERATION_ID to IntegerValue(1223),
Meta.COLUMN_NAME_AB_META to StringValue("{\"changes\":[],\"syncId\":43}"),
"${columnName}Null" to NullValue
private fun createRecord(column: String): Map<String, AirbyteValue> {
return mapOf(
column to IntegerValue(value = 42),
Meta.COLUMN_NAME_AB_GENERATION_ID to NullValue,
Meta.COLUMN_NAME_AB_RAW_ID to StringValue("raw-id-1"),
Meta.COLUMN_NAME_AB_EXTRACTED_AT to IntegerValue(1234567890),
Meta.COLUMN_NAME_AB_META to StringValue("meta-data-foo"),
)
}
}

View File

@@ -15,12 +15,8 @@ import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.mockk.every
import io.mockk.mockk
import kotlin.collections.plus
import io.airbyte.cdk.load.schema.model.ColumnSchema
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
private val AIRBYTE_COLUMN_TYPES_MAP =
@@ -58,28 +54,16 @@ private fun createExpected(
internal class SnowflakeRawRecordFormatterTest {
private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
@BeforeEach
fun setup() {
snowflakeColumnUtils = mockk {
every { getFormattedDefaultColumnNames(any()) } returns
AIRBYTE_COLUMN_TYPES_MAP.keys.toList()
}
}
@Test
fun testFormatting() {
val columnName = "test-column-name"
val columnValue = "test-column-value"
val columns = AIRBYTE_COLUMN_TYPES_MAP
val record = createRecord(columnName = columnName, columnValue = columnValue)
val formatter =
SnowflakeRawRecordFormatter(
columns = AIRBYTE_COLUMN_TYPES_MAP,
snowflakeColumnUtils = snowflakeColumnUtils
)
val formattedValue = formatter.format(record)
val formatter = SnowflakeRawRecordFormatter()
// RawRecordFormatter doesn't use columnSchema but still needs one per interface
val dummyColumnSchema = ColumnSchema(emptyMap(), emptyMap(), emptyMap())
val formattedValue = formatter.format(record, dummyColumnSchema)
val expectedValue =
createExpected(
record = record,
@@ -93,33 +77,28 @@ internal class SnowflakeRawRecordFormatterTest {
fun testFormattingMigratedFromPreviousVersion() {
val columnName = "test-column-name"
val columnValue = "test-column-value"
val columnsMap =
linkedMapOf(
COLUMN_NAME_AB_EXTRACTED_AT to "TIMESTAMP_TZ(9)",
COLUMN_NAME_AB_LOADED_AT to "TIMESTAMP_TZ(9)",
COLUMN_NAME_AB_META to "VARIANT",
COLUMN_NAME_DATA to "VARIANT",
COLUMN_NAME_AB_RAW_ID to "VARCHAR(16777216)",
COLUMN_NAME_AB_GENERATION_ID to "NUMBER(38,0)",
)
val record = createRecord(columnName = columnName, columnValue = columnValue)
val formatter =
SnowflakeRawRecordFormatter(
columns = columnsMap,
snowflakeColumnUtils = snowflakeColumnUtils
)
val formattedValue = formatter.format(record)
val formatter = SnowflakeRawRecordFormatter()
// RawRecordFormatter doesn't use columnSchema but still needs one per interface
val dummyColumnSchema = ColumnSchema(emptyMap(), emptyMap(), emptyMap())
val formattedValue = formatter.format(record, dummyColumnSchema)
// The formatter outputs in a fixed order regardless of input column order:
// 1. AB_RAW_ID
// 2. AB_EXTRACTED_AT
// 3. AB_META
// 4. AB_GENERATION_ID
// 5. AB_LOADED_AT
// 6. DATA (JSON with remaining columns)
val expectedValue =
createExpected(
record = record,
columns = columnsMap,
airbyteColumns = columnsMap.keys.toList(),
)
.toMutableList()
expectedValue.add(
columnsMap.keys.indexOf(COLUMN_NAME_DATA),
"{\"$columnName\":\"$columnValue\"}"
)
listOf(
record[COLUMN_NAME_AB_RAW_ID]!!.toCsvValue(),
record[COLUMN_NAME_AB_EXTRACTED_AT]!!.toCsvValue(),
record[COLUMN_NAME_AB_META]!!.toCsvValue(),
record[COLUMN_NAME_AB_GENERATION_ID]!!.toCsvValue(),
record[COLUMN_NAME_AB_LOADED_AT]!!.toCsvValue(),
"{\"$columnName\":\"$columnValue\"}"
)
assertEquals(expectedValue, formattedValue)
}
}

View File

@@ -4,66 +4,58 @@
package io.airbyte.integrations.destination.snowflake.write.load
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerValue
import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.data.csv.toCsvValue
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.mockk.every
import io.mockk.mockk
import java.util.AbstractMap
import kotlin.collections.component1
import kotlin.collections.component2
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
private val AIRBYTE_COLUMN_TYPES_MAP =
linkedMapOf(
COLUMN_NAME_AB_RAW_ID to "VARCHAR(16777216)",
COLUMN_NAME_AB_EXTRACTED_AT to "TIMESTAMP_TZ(9)",
COLUMN_NAME_AB_META to "VARIANT",
COLUMN_NAME_AB_GENERATION_ID to "NUMBER(38,0)",
)
.mapKeys { it.key.toSnowflakeCompatibleName() }
internal class SnowflakeSchemaRecordFormatterTest {
private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
private fun createColumnSchema(userColumns: Map<String, String>): ColumnSchema {
val finalSchema = linkedMapOf<String, ColumnType>()
val inputToFinalColumnNames = mutableMapOf<String, String>()
val inputSchema = mutableMapOf<String, FieldType>()
@BeforeEach
fun setup() {
snowflakeColumnUtils = mockk {
every { getFormattedDefaultColumnNames(any()) } returns
AIRBYTE_COLUMN_TYPES_MAP.keys.toList()
// Add user columns
userColumns.forEach { (name, type) ->
val finalName = name.toSnowflakeCompatibleName()
finalSchema[finalName] = ColumnType(type, true)
inputToFinalColumnNames[name] = finalName
inputSchema[name] = FieldType(StringType, nullable = true)
}
return ColumnSchema(
inputToFinalColumnNames = inputToFinalColumnNames,
finalSchema = finalSchema,
inputSchema = inputSchema
)
}
@Test
fun testFormatting() {
val columnName = "test-column-name"
val columnValue = "test-column-value"
val columns =
(AIRBYTE_COLUMN_TYPES_MAP + linkedMapOf(columnName to "VARCHAR(16777216)")).mapKeys {
it.key.toSnowflakeCompatibleName()
}
val userColumns = mapOf(columnName to "VARCHAR(16777216)")
val columnSchema = createColumnSchema(userColumns)
val record = createRecord(columnName, columnValue)
val formatter =
SnowflakeSchemaRecordFormatter(
columns = columns as LinkedHashMap<String, String>,
snowflakeColumnUtils = snowflakeColumnUtils
)
val formattedValue = formatter.format(record)
val formatter = SnowflakeSchemaRecordFormatter()
val formattedValue = formatter.format(record, columnSchema)
val expectedValue =
createExpected(
record = record,
columns = columns,
columnSchema = columnSchema,
)
assertEquals(expectedValue, formattedValue)
}
@@ -72,21 +64,15 @@ internal class SnowflakeSchemaRecordFormatterTest {
fun testFormattingVariant() {
val columnName = "test-column-name"
val columnValue = "{\"test\": \"test-value\"}"
val columns =
(AIRBYTE_COLUMN_TYPES_MAP + linkedMapOf(columnName to "VARIANT")).mapKeys {
it.key.toSnowflakeCompatibleName()
}
val userColumns = mapOf(columnName to "VARIANT")
val columnSchema = createColumnSchema(userColumns)
val record = createRecord(columnName, columnValue)
val formatter =
SnowflakeSchemaRecordFormatter(
columns = columns as LinkedHashMap<String, String>,
snowflakeColumnUtils = snowflakeColumnUtils
)
val formattedValue = formatter.format(record)
val formatter = SnowflakeSchemaRecordFormatter()
val formattedValue = formatter.format(record, columnSchema)
val expectedValue =
createExpected(
record = record,
columns = columns,
columnSchema = columnSchema,
)
assertEquals(expectedValue, formattedValue)
}
@@ -95,23 +81,16 @@ internal class SnowflakeSchemaRecordFormatterTest {
fun testFormattingMissingColumn() {
val columnName = "test-column-name"
val columnValue = "test-column-value"
val columns =
AIRBYTE_COLUMN_TYPES_MAP +
linkedMapOf(
columnName to "VARCHAR(16777216)",
"missing-column" to "VARCHAR(16777216)"
)
val userColumns =
mapOf(columnName to "VARCHAR(16777216)", "missing-column" to "VARCHAR(16777216)")
val columnSchema = createColumnSchema(userColumns)
val record = createRecord(columnName, columnValue)
val formatter =
SnowflakeSchemaRecordFormatter(
columns = columns as LinkedHashMap<String, String>,
snowflakeColumnUtils = snowflakeColumnUtils
)
val formattedValue = formatter.format(record)
val formatter = SnowflakeSchemaRecordFormatter()
val formattedValue = formatter.format(record, columnSchema)
val expectedValue =
createExpected(
record = record,
columns = columns,
columnSchema = columnSchema,
filterMissing = false,
)
assertEquals(expectedValue, formattedValue)
@@ -128,16 +107,37 @@ internal class SnowflakeSchemaRecordFormatterTest {
private fun createExpected(
record: Map<String, AirbyteValue>,
columns: Map<String, String>,
columnSchema: ColumnSchema,
filterMissing: Boolean = true,
) =
record.entries
.associate { entry -> entry.key.toSnowflakeCompatibleName() to entry.value }
.map { entry -> AbstractMap.SimpleEntry(entry.key, entry.value.toCsvValue()) }
.sortedBy { entry ->
if (columns.keys.indexOf(entry.key) > -1) columns.keys.indexOf(entry.key)
else Int.MAX_VALUE
): List<Any> {
val columns = columnSchema.finalSchema.keys.toList()
val result = mutableListOf<Any>()
// Add meta columns first in the expected order
result.add(record[COLUMN_NAME_AB_RAW_ID]?.toCsvValue() ?: "")
result.add(record[COLUMN_NAME_AB_EXTRACTED_AT]?.toCsvValue() ?: "")
result.add(record[COLUMN_NAME_AB_META]?.toCsvValue() ?: "")
result.add(record[COLUMN_NAME_AB_GENERATION_ID]?.toCsvValue() ?: "")
// Add user columns
val userColumns =
columns.filterNot { col ->
listOf(
COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName(),
COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName(),
COLUMN_NAME_AB_META.toSnowflakeCompatibleName(),
COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()
)
.contains(col)
}
.filter { (k, _) -> if (filterMissing) columns.contains(k) else true }
.map { it.value }
userColumns.forEach { columnName ->
val value = record[columnName] ?: if (!filterMissing) NullValue else null
if (value != null || !filterMissing) {
result.add(value?.toCsvValue() ?: "")
}
}
return result
}
}

Some files were not shown because too many files have changed in this diff