
Merge branch 'master' into devin/1765393705-fix-youtube-analytics-job-creation

This commit is contained in:
Aaron ("AJ") Steers
2025-12-18 16:11:00 -08:00
committed by GitHub
1402 changed files with 83973 additions and 10058 deletions

View File

@@ -139,8 +139,8 @@ runs:
CONNECTOR_VERSION_TAG="${{ inputs.tag-override }}"
echo "🏷 Using provided tag override: $CONNECTOR_VERSION_TAG"
elif [[ "${{ inputs.release-type }}" == "pre-release" ]]; then
-hash=$(git rev-parse --short=10 HEAD)
-CONNECTOR_VERSION_TAG="${CONNECTOR_VERSION}-dev.${hash}"
+hash=$(git rev-parse --short=7 HEAD)
+CONNECTOR_VERSION_TAG="${CONNECTOR_VERSION}-preview.${hash}"
echo "🏷 Using pre-release tag: $CONNECTOR_VERSION_TAG"
else
CONNECTOR_VERSION_TAG="$CONNECTOR_VERSION"

View File

@@ -21,7 +21,7 @@ As needed or by request, Airbyte Maintainers can execute the following slash com
- `/run-live-tests` - Runs live tests for the modified connector(s).
- `/run-regression-tests` - Runs regression tests for the modified connector(s).
- `/build-connector-images` - Builds and publishes a pre-release docker image for the modified connector(s).
-- `/publish-connectors-prerelease` - Publishes pre-release connector builds (tagged as `{version}-dev.{git-sha}`) for all modified connectors in the PR.
+- `/publish-connectors-prerelease` - Publishes pre-release connector builds (tagged as `{version}-preview.{git-sha}`) for all modified connectors in the PR.
If you have any questions, feel free to ask in the PR comments or join our [Slack community](https://airbytehq.slack.com/).

View File

@@ -28,7 +28,11 @@ Airbyte Maintainers (that's you!) can execute the following slash commands on yo
- `/run-live-tests` - Runs live tests for the modified connector(s).
- `/run-regression-tests` - Runs regression tests for the modified connector(s).
- `/build-connector-images` - Builds and publishes a pre-release docker image for the modified connector(s).
-- `/publish-connectors-prerelease` - Publishes pre-release connector builds (tagged as `{version}-dev.{git-sha}`) for all modified connectors in the PR.
+- `/publish-connectors-prerelease` - Publishes pre-release connector builds (tagged as `{version}-preview.{git-sha}`) for all modified connectors in the PR.
+- Connector release lifecycle (AI-powered):
+  - `/ai-prove-fix` - Runs prerelease readiness checks, including testing against customer connections.
+  - `/ai-canary-prerelease` - Rolls out prerelease to 5-10 connections for canary testing.
+  - `/ai-release-watch` - Monitors rollout post-release and tracks sync success rates.
- JVM connectors:
  - `/update-connector-cdk-version connector=<CONNECTOR_NAME>` - Updates the specified connector to the latest CDK version.
    Example: `/update-connector-cdk-version connector=destination-bigquery`

View File

@@ -0,0 +1,72 @@
name: AI Canary Prerelease Command
on:
workflow_dispatch:
inputs:
pr:
description: "Pull request number (if triggered from a PR)"
type: number
required: false
comment-id:
description: "The comment-id of the slash command. Used to update the comment with the status."
required: false
repo:
description: "Repo (passed by slash command dispatcher)"
required: false
default: "airbytehq/airbyte"
gitref:
description: "Git ref (passed by slash command dispatcher)"
required: false
run-name: "AI Canary Prerelease for PR #${{ github.event.inputs.pr }}"
permissions:
contents: read
issues: write
pull-requests: read
jobs:
ai-canary-prerelease:
runs-on: ubuntu-latest
steps:
- name: Get job variables
id: job-vars
run: |
echo "run-url=https://github.com/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID" >> $GITHUB_OUTPUT
- name: Checkout code
uses: actions/checkout@v4
- name: Authenticate as GitHub App
uses: actions/create-github-app-token@v2
id: get-app-token
with:
owner: "airbytehq"
repositories: "airbyte,oncall"
app-id: ${{ secrets.OCTAVIA_BOT_APP_ID }}
private-key: ${{ secrets.OCTAVIA_BOT_PRIVATE_KEY }}
- name: Post start comment
if: inputs.comment-id != ''
uses: peter-evans/create-or-update-comment@v4
with:
token: ${{ steps.get-app-token.outputs.token }}
comment-id: ${{ inputs.comment-id }}
issue-number: ${{ inputs.pr }}
body: |
> **AI Canary Prerelease Started**
>
> Rolling out to 5-10 connections, watching results, and reporting findings.
> [View workflow run](${{ steps.job-vars.outputs.run-url }})
- name: Run AI Canary Prerelease
uses: aaronsteers/devin-action@main
with:
comment-id: ${{ inputs.comment-id }}
issue-number: ${{ inputs.pr }}
playbook-macro: "!canary_prerelease"
devin-token: ${{ secrets.DEVIN_AI_API_KEY }}
github-token: ${{ steps.get-app-token.outputs.token }}
start-message: "🐤 **AI Canary Prerelease session starting...** Rolling out to 5-10 connections, watching results, and reporting findings. [View playbook](https://github.com/airbytehq/oncall/blob/main/prompts/playbooks/canary_prerelease.md)"
tags: |
ai-oncall

View File

@@ -0,0 +1,72 @@
name: AI Prove Fix Command
on:
workflow_dispatch:
inputs:
pr:
description: "Pull request number (if triggered from a PR)"
type: number
required: false
comment-id:
description: "The comment-id of the slash command. Used to update the comment with the status."
required: false
repo:
description: "Repo (passed by slash command dispatcher)"
required: false
default: "airbytehq/airbyte"
gitref:
description: "Git ref (passed by slash command dispatcher)"
required: false
run-name: "AI Prove Fix for PR #${{ github.event.inputs.pr }}"
permissions:
contents: read
issues: write
pull-requests: read
jobs:
ai-prove-fix:
runs-on: ubuntu-latest
steps:
- name: Get job variables
id: job-vars
run: |
echo "run-url=https://github.com/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID" >> $GITHUB_OUTPUT
- name: Checkout code
uses: actions/checkout@v4
- name: Authenticate as GitHub App
uses: actions/create-github-app-token@v2
id: get-app-token
with:
owner: "airbytehq"
repositories: "airbyte,oncall"
app-id: ${{ secrets.OCTAVIA_BOT_APP_ID }}
private-key: ${{ secrets.OCTAVIA_BOT_PRIVATE_KEY }}
- name: Post start comment
if: inputs.comment-id != ''
uses: peter-evans/create-or-update-comment@v4
with:
token: ${{ steps.get-app-token.outputs.token }}
comment-id: ${{ inputs.comment-id }}
issue-number: ${{ inputs.pr }}
body: |
> **AI Prove Fix Started**
>
> Running readiness checks and testing against customer connections.
> [View workflow run](${{ steps.job-vars.outputs.run-url }})
- name: Run AI Prove Fix
uses: aaronsteers/devin-action@main
with:
comment-id: ${{ inputs.comment-id }}
issue-number: ${{ inputs.pr }}
playbook-macro: "!prove_fix"
devin-token: ${{ secrets.DEVIN_AI_API_KEY }}
github-token: ${{ steps.get-app-token.outputs.token }}
start-message: "🔍 **AI Prove Fix session starting...** Running readiness checks and testing against customer connections. [View playbook](https://github.com/airbytehq/oncall/blob/main/prompts/playbooks/prove_fix.md)"
tags: |
ai-oncall

View File

@@ -0,0 +1,72 @@
name: AI Release Watch Command
on:
workflow_dispatch:
inputs:
pr:
description: "Pull request number (if triggered from a PR)"
type: number
required: false
comment-id:
description: "The comment-id of the slash command. Used to update the comment with the status."
required: false
repo:
description: "Repo (passed by slash command dispatcher)"
required: false
default: "airbytehq/airbyte"
gitref:
description: "Git ref (passed by slash command dispatcher)"
required: false
run-name: "AI Release Watch for PR #${{ github.event.inputs.pr }}"
permissions:
contents: read
issues: write
pull-requests: read
jobs:
ai-release-watch:
runs-on: ubuntu-latest
steps:
- name: Get job variables
id: job-vars
run: |
echo "run-url=https://github.com/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID" >> $GITHUB_OUTPUT
- name: Checkout code
uses: actions/checkout@v4
- name: Authenticate as GitHub App
uses: actions/create-github-app-token@v2
id: get-app-token
with:
owner: "airbytehq"
repositories: "airbyte,oncall"
app-id: ${{ secrets.OCTAVIA_BOT_APP_ID }}
private-key: ${{ secrets.OCTAVIA_BOT_PRIVATE_KEY }}
- name: Post start comment
if: inputs.comment-id != ''
uses: peter-evans/create-or-update-comment@v4
with:
token: ${{ steps.get-app-token.outputs.token }}
comment-id: ${{ inputs.comment-id }}
issue-number: ${{ inputs.pr }}
body: |
> **AI Release Watch Started**
>
> Monitoring rollout and tracking sync success rates.
> [View workflow run](${{ steps.job-vars.outputs.run-url }})
- name: Run AI Release Watch
uses: aaronsteers/devin-action@main
with:
comment-id: ${{ inputs.comment-id }}
issue-number: ${{ inputs.pr }}
playbook-macro: "!release_watch"
devin-token: ${{ secrets.DEVIN_AI_API_KEY }}
github-token: ${{ steps.get-app-token.outputs.token }}
start-message: "👁️ **AI Release Watch session starting...** Monitoring rollout and tracking sync success rates. [View playbook](https://github.com/airbytehq/oncall/blob/main/prompts/playbooks/release_watch.md)"
tags: |
ai-oncall

View File

@@ -104,7 +104,7 @@ jobs:
if: steps.check-support-level.outputs.metadata_file == 'true' && steps.check-support-level.outputs.community_support == 'true'
env:
PROMPT_TEXT: "The commit to review is ${{ github.sha }}. This commit was pushed to master and may contain connector changes that need documentation updates."
-uses: aaronsteers/devin-action@0d74d6d9ff1b16ada5966dc31af53a9d155759f4 # Pinned to specific commit for security
+uses: aaronsteers/devin-action@98d15ae93d1848914f5ab8e9ce45341182958d27 # v0.1.7 - Pinned to specific commit for security
with:
devin-token: ${{ secrets.DEVIN_AI_API_KEY }}
github-token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -0,0 +1,28 @@
name: Label Community PRs
# This workflow automatically adds the "community" label to PRs from forks.
# This enables automatic tracking on the Community PRs project board.
on:
pull_request_target:
types:
- opened
- reopened
jobs:
label-community-pr:
name: Add "Community" Label to PR
# Only run for PRs from forks
if: github.event.pull_request.head.repo.fork == true
runs-on: ubuntu-24.04
permissions:
issues: write
pull-requests: write
steps:
- name: Add community label
# This action uses GitHub's addLabels API, which is idempotent.
# If the label already exists, the API call succeeds without error.
uses: actions-ecosystem/action-add-labels@bd52874380e3909a1ac983768df6976535ece7f8 # v1.1.3
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
labels: community

View File

@@ -3,7 +3,7 @@ name: Publish Connectors Pre-release
# It can be triggered via the /publish-connectors-prerelease slash command from PR comments,
# or via the MCP tool `publish_connector_to_airbyte_registry`.
#
-# Pre-release versions are tagged with the format: {version}-dev.{10-char-git-sha}
+# Pre-release versions are tagged with the format: {version}-preview.{7-char-git-sha}
# These versions are NOT eligible for semver auto-advancement but ARE available
# for version pinning via the scoped_configuration API.
#
@@ -66,7 +66,7 @@ jobs:
- name: Get short SHA
id: get-sha
run: |
-SHORT_SHA=$(git rev-parse --short=10 HEAD)
+SHORT_SHA=$(git rev-parse --short=7 HEAD)
echo "short-sha=$SHORT_SHA" >> $GITHUB_OUTPUT
- name: Get job variables
@@ -135,7 +135,7 @@ jobs:
> Publishing pre-release build for connector `${{ steps.resolve-connector.outputs.connector-name }}`.
> Branch: `${{ inputs.gitref }}`
>
-> Pre-release versions will be tagged as `{version}-dev.${{ steps.get-sha.outputs.short-sha }}`
+> Pre-release versions will be tagged as `{version}-preview.${{ steps.get-sha.outputs.short-sha }}`
> and are available for version pinning via the scoped_configuration API.
>
> [View workflow run](${{ steps.job-vars.outputs.run-url }})
@@ -147,6 +147,7 @@ jobs:
with:
connectors: ${{ format('--name={0}', needs.init.outputs.connector-name) }}
release-type: pre-release
+gitref: ${{ inputs.gitref }}
secrets: inherit
post-completion:
@@ -176,13 +177,12 @@
id: message-vars
run: |
CONNECTOR_NAME="${{ needs.init.outputs.connector-name }}"
-SHORT_SHA="${{ needs.init.outputs.short-sha }}"
-VERSION="${{ needs.init.outputs.connector-version }}"
-if [[ -n "$VERSION" ]]; then
-DOCKER_TAG="${VERSION}-dev.${SHORT_SHA}"
-else
-DOCKER_TAG="{version}-dev.${SHORT_SHA}"
+# Use the actual docker-image-tag from the publish workflow output
+DOCKER_TAG="${{ needs.publish.outputs.docker-image-tag }}"
+if [[ -z "$DOCKER_TAG" ]]; then
+echo "::error::docker-image-tag output is missing from publish workflow. This is unexpected."
+exit 1
fi
echo "connector_name=$CONNECTOR_NAME" >> $GITHUB_OUTPUT

View File

@@ -21,6 +21,14 @@ on:
required: false
default: false
type: boolean
+gitref:
+description: "Git ref (branch or SHA) to build connectors from. Used by pre-release workflow to build from PR branches."
+required: false
+type: string
+outputs:
+docker-image-tag:
+description: "Docker image tag used when publishing. For single-connector callers only; multi-connector callers should not rely on this output."
+value: ${{ jobs.publish_connector_registry_entries.outputs.docker-image-tag }}
workflow_dispatch:
inputs:
connectors:
@@ -48,6 +56,7 @@ jobs:
# v4
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955
with:
+ref: ${{ inputs.gitref || '' }}
fetch-depth: 2 # Required so we can conduct a diff from the previous commit to understand what connectors have changed.
submodules: true # Required for the enterprise repo since it uses a submodule that needs to exist for this workflow to run successfully.
- name: List connectors to publish [manual]
@@ -105,6 +114,7 @@ jobs:
# v4
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955
with:
+ref: ${{ inputs.gitref || '' }}
fetch-depth: 2 # Required so we can conduct a diff from the previous commit to understand what connectors have changed.
submodules: true # Required for the enterprise repo since it uses a submodule that needs to exist for this workflow to run successfully.
@@ -250,11 +260,14 @@ jobs:
max-parallel: 5
# Allow all jobs to run, even if one fails
fail-fast: false
+outputs:
+docker-image-tag: ${{ steps.connector-metadata.outputs.docker-image-tag }}
steps:
- name: Checkout Airbyte
# v4
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955
with:
+ref: ${{ inputs.gitref || '' }}
fetch-depth: 2 # Required so we can conduct a diff from the previous commit to understand what connectors have changed.
submodules: true # Required for the enterprise repo since it uses a submodule that needs to exist for this workflow to run successfully.
@@ -292,8 +305,8 @@ jobs:
echo "connector-version=$(poe -qq get-version)" | tee -a $GITHUB_OUTPUT
CONNECTOR_VERSION=$(poe -qq get-version)
if [[ "${{ inputs.release-type }}" == "pre-release" ]]; then
-hash=$(git rev-parse --short=10 HEAD)
-echo "docker-image-tag=${CONNECTOR_VERSION}-dev.${hash}" | tee -a $GITHUB_OUTPUT
+hash=$(git rev-parse --short=7 HEAD)
+echo "docker-image-tag=${CONNECTOR_VERSION}-preview.${hash}" | tee -a $GITHUB_OUTPUT
echo "release-type-flag=--pre-release" | tee -a $GITHUB_OUTPUT
else
echo "docker-image-tag=${CONNECTOR_VERSION}" | tee -a $GITHUB_OUTPUT
@@ -349,6 +362,7 @@ jobs:
# v4
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955
with:
+ref: ${{ inputs.gitref || '' }}
submodules: true # Required for the enterprise repo since it uses a submodule that needs to exist for this workflow to run successfully.
- name: Match GitHub User to Slack User
id: match-github-to-slack-user
@@ -381,6 +395,7 @@ jobs:
# v4
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955
with:
+ref: ${{ inputs.gitref || '' }}
submodules: true # Required for the enterprise repo since it uses a submodule that needs to exist for this workflow to run successfully.
- name: Notify PagerDuty
id: pager-duty

View File

@@ -35,6 +35,9 @@ jobs:
issue-type: both
commands: |
+ai-canary-prerelease
+ai-prove-fix
+ai-release-watch
approve-regression-tests
bump-bulk-cdk-version
bump-progressive-rollout-version

View File

@@ -0,0 +1,70 @@
name: Sync Agent Connector Docs
on:
schedule:
- cron: "0 */2 * * *" # Every 2 hours
workflow_dispatch: # Manual trigger
jobs:
sync-docs:
runs-on: ubuntu-latest
steps:
- name: Checkout airbyte repo
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
- name: Checkout airbyte-agent-connectors
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
repository: airbytehq/airbyte-agent-connectors
path: agent-connectors-source
- name: Sync connector docs
run: |
DEST_DIR="docs/ai-agents/connectors"
mkdir -p "$DEST_DIR"
for connector_dir in agent-connectors-source/connectors/*/; do
connector=$(basename "$connector_dir")
# Only delete/recreate the specific connector subdirectory
# This leaves any files directly in $DEST_DIR untouched
rm -rf "$DEST_DIR/$connector"
mkdir -p "$DEST_DIR/$connector"
# Copy all markdown files for this connector
for md_file in "$connector_dir"/*.md; do
if [ -f "$md_file" ]; then
cp "$md_file" "$DEST_DIR/$connector/"
fi
done
done
echo "Synced $(ls -d $DEST_DIR/*/ 2>/dev/null | wc -l) connectors"
- name: Cleanup temporary checkout
run: rm -rf agent-connectors-source
- name: Authenticate as GitHub App
uses: actions/create-github-app-token@v2
id: get-app-token
with:
owner: "airbytehq"
repositories: "airbyte"
app-id: ${{ secrets.OCTAVIA_BOT_APP_ID }}
private-key: ${{ secrets.OCTAVIA_BOT_PRIVATE_KEY }}
- name: Create PR if changes
uses: peter-evans/create-pull-request@0979079bc20c05bbbb590a56c21c4e2b1d1f1bbe # v6
with:
token: ${{ steps.get-app-token.outputs.token }}
commit-message: "docs: sync agent connector docs from airbyte-agent-connectors repo"
branch: auto-sync-ai-connector-docs
delete-branch: true
title: "docs: sync agent connector docs from airbyte-agent-connectors repo"
body: |
Automated sync of agent connector docs from airbyte-agent-connectors.
This PR was automatically created by the sync-agent-connector-docs workflow.
labels: |
documentation
auto-merge

.markdownlintignore Normal file
View File

@@ -0,0 +1,3 @@
# Ignore auto-generated connector documentation files synced from airbyte-agent-connectors repo
# These files are generated and have formatting that doesn't conform to markdownlint rules
docs/ai-agents/connectors/**

View File

@@ -1,3 +1,21 @@
## Version 0.1.91
load cdk: upsert records test uses proper target schema
## Version 0.1.90
load cdk: components tests: data coercion tests cover all data types
## Version 0.1.89
load cdk: components tests: data coercion tests for int+number
## Version 0.1.88
**Load CDK**
* Add CDC_CURSOR_COLUMN_NAME constant.
## Version 0.1.87
**Load CDK**

View File

@@ -4,4 +4,13 @@
package io.airbyte.cdk.load.table
+/**
+ * CDC meta column names.
+ *
+ * Note: These CDC column names are brittle: they are separate from, yet coupled to, the logic
+ * that sources use to generate these column names. See
+ * [io.airbyte.integrations.source.mssql.MsSqlSourceOperations.MsSqlServerCdcMetaFields] for an
+ * example.
+ */
const val CDC_DELETED_AT_COLUMN = "_ab_cdc_deleted_at"
+const val CDC_CURSOR_COLUMN = "_ab_cdc_cursor"

View File

@@ -0,0 +1,859 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.component
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.ArrayValue
import io.airbyte.cdk.load.data.DateValue
import io.airbyte.cdk.load.data.IntegerValue
import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.data.NumberValue
import io.airbyte.cdk.load.data.ObjectValue
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.data.TimeWithTimezoneValue
import io.airbyte.cdk.load.data.TimeWithoutTimezoneValue
import io.airbyte.cdk.load.data.TimestampWithTimezoneValue
import io.airbyte.cdk.load.data.TimestampWithoutTimezoneValue
import io.airbyte.cdk.load.dataflow.transform.ValueCoercer
import io.airbyte.cdk.load.util.serializeToString
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Reason
import java.math.BigDecimal
import java.math.BigInteger
import java.time.LocalDate
import java.time.LocalDateTime
import java.time.OffsetDateTime
import java.time.format.DateTimeFormatter
import java.time.format.DateTimeFormatterBuilder
import java.time.format.SignStyle
import java.time.temporal.ChronoField
import org.junit.jupiter.params.provider.Arguments
/*
* This file defines "interesting values" for all data types, along with expected behavior for those values.
* You're free to define your own values/behavior depending on the destination, but it's recommended
* that you try to match behavior to an existing fixture.
*
* Classes also include some convenience functions for JUnit. For example, you could annotate your
* method with:
* ```kotlin
* @ParameterizedTest
* @MethodSource("io.airbyte.cdk.load.component.DataCoercionIntegerFixtures#int64")
* ```
*
* By convention, all fixtures are declared as:
* 1. One or more `val <name>: List<DataCoercionTestCase>` (each case holding the input value and
* the expected output value)
* 2. One or more `fun <name>(): List<Arguments> = <name>.toArgs()`, which can be provided to JUnit's MethodSource
*
* If you need to mutate fixtures in some way, you should reference the `val`, and use the `toArgs()`
* extension function to convert it to JUnit's Arguments class. See [DataCoercionIntegerFixtures.int64AsBigInteger]
* for an example.
*/
object DataCoercionIntegerFixtures {
// "9".repeat(38)
val numeric38_0Max = bigint("99999999999999999999999999999999999999")
val numeric38_0Min = bigint("-99999999999999999999999999999999999999")
const val ZERO = "0"
const val ONE = "1"
const val NEGATIVE_ONE = "-1"
const val FORTY_TWO = "42"
const val NEGATIVE_FORTY_TWO = "-42"
const val INT32_MAX = "int32 max"
const val INT32_MIN = "int32 min"
const val INT32_MAX_PLUS_ONE = "int32_max + 1"
const val INT32_MIN_MINUS_ONE = "int32_min - 1"
const val INT64_MAX = "int64 max"
const val INT64_MIN = "int64 min"
const val INT64_MAX_PLUS_ONE = "int64_max + 1"
const val INT64_MIN_MINUS_1 = "int64_min - 1"
const val NUMERIC_38_0_MAX = "numeric(38,0) max"
const val NUMERIC_38_0_MIN = "numeric(38,0) min"
const val NUMERIC_38_0_MAX_PLUS_ONE = "numeric(38,0)_max + 1"
const val NUMERIC_38_0_MIN_MINUS_ONE = "numeric(38,0)_min - 1"
/**
* Many destinations use int64 to represent integers. In this case, we null out any value beyond
* Long.MIN/MAX_VALUE.
*/
val int64 =
listOf(
case(NULL, NullValue, null),
case(ZERO, IntegerValue(0), 0L),
case(ONE, IntegerValue(1), 1L),
case(NEGATIVE_ONE, IntegerValue(-1), -1L),
case(FORTY_TWO, IntegerValue(42), 42L),
case(NEGATIVE_FORTY_TWO, IntegerValue(-42), -42L),
// int32 bounds, and slightly out of bounds
case(INT32_MAX, IntegerValue(Integer.MAX_VALUE.toLong()), Integer.MAX_VALUE.toLong()),
case(INT32_MIN, IntegerValue(Integer.MIN_VALUE.toLong()), Integer.MIN_VALUE.toLong()),
case(
INT32_MAX_PLUS_ONE,
IntegerValue(Integer.MAX_VALUE.toLong() + 1),
Integer.MAX_VALUE.toLong() + 1
),
case(
INT32_MIN_MINUS_ONE,
IntegerValue(Integer.MIN_VALUE.toLong() - 1),
Integer.MIN_VALUE.toLong() - 1
),
// int64 bounds, and slightly out of bounds
case(INT64_MAX, IntegerValue(Long.MAX_VALUE), Long.MAX_VALUE),
case(INT64_MIN, IntegerValue(Long.MIN_VALUE), Long.MIN_VALUE),
// values out of int64 bounds are nulled
case(
INT64_MAX_PLUS_ONE,
IntegerValue(bigint(Long.MAX_VALUE) + BigInteger.ONE),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
INT64_MIN_MINUS_1,
IntegerValue(bigint(Long.MIN_VALUE) - BigInteger.ONE),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
// NUMERIC(38, 0) bounds, and slightly out of bounds
// (these are all out of bounds for an int64 value, so they all get nulled)
case(
NUMERIC_38_0_MAX,
IntegerValue(numeric38_0Max),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
NUMERIC_38_0_MIN,
IntegerValue(numeric38_0Min),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
NUMERIC_38_0_MAX_PLUS_ONE,
IntegerValue(numeric38_0Max + BigInteger.ONE),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
NUMERIC_38_0_MIN_MINUS_ONE,
IntegerValue(numeric38_0Min - BigInteger.ONE),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
)
/**
* Many destination warehouses represent integers as a fixed-point type with 38 digits of
* precision. In this case, we only need to null out numbers larger than `1e38 - 1` / smaller
* than `-1e38 + 1`.
*/
val numeric38_0 =
listOf(
case(NULL, NullValue, null),
case(ZERO, IntegerValue(0), bigint(0L)),
case(ONE, IntegerValue(1), bigint(1L)),
case(NEGATIVE_ONE, IntegerValue(-1), bigint(-1L)),
case(FORTY_TWO, IntegerValue(42), bigint(42L)),
case(NEGATIVE_FORTY_TWO, IntegerValue(-42), bigint(-42L)),
// int32 bounds, and slightly out of bounds
case(
INT32_MAX,
IntegerValue(Integer.MAX_VALUE.toLong()),
bigint(Integer.MAX_VALUE.toLong())
),
case(
INT32_MIN,
IntegerValue(Integer.MIN_VALUE.toLong()),
bigint(Integer.MIN_VALUE.toLong())
),
case(
INT32_MAX_PLUS_ONE,
IntegerValue(Integer.MAX_VALUE.toLong() + 1),
bigint(Integer.MAX_VALUE.toLong() + 1)
),
case(
INT32_MIN_MINUS_ONE,
IntegerValue(Integer.MIN_VALUE.toLong() - 1),
bigint(Integer.MIN_VALUE.toLong() - 1)
),
// int64 bounds, and slightly out of bounds
case(INT64_MAX, IntegerValue(Long.MAX_VALUE), bigint(Long.MAX_VALUE)),
case(INT64_MIN, IntegerValue(Long.MIN_VALUE), bigint(Long.MIN_VALUE)),
case(
INT64_MAX_PLUS_ONE,
IntegerValue(bigint(Long.MAX_VALUE) + BigInteger.ONE),
bigint(Long.MAX_VALUE) + BigInteger.ONE
),
case(
INT64_MIN_MINUS_1,
IntegerValue(bigint(Long.MIN_VALUE) - BigInteger.ONE),
bigint(Long.MIN_VALUE) - BigInteger.ONE
),
// NUMERIC(38, 0) bounds, and slightly out of bounds
case(NUMERIC_38_0_MAX, IntegerValue(numeric38_0Max), numeric38_0Max),
case(NUMERIC_38_0_MIN, IntegerValue(numeric38_0Min), numeric38_0Min),
// These values exceed the 38-digit range, so they get nulled out
case(
NUMERIC_38_0_MAX_PLUS_ONE,
IntegerValue(numeric38_0Max + BigInteger.ONE),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
NUMERIC_38_0_MIN_MINUS_ONE,
IntegerValue(numeric38_0Min - BigInteger.ONE),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
)
@JvmStatic fun int64() = int64.toArgs()
/**
* Convenience fixture if your [TestTableOperationsClient] returns integers as [BigInteger]
* rather than [Long].
*/
@JvmStatic
fun int64AsBigInteger() =
int64.map { it.copy(outputValue = it.outputValue?.let { bigint(it as Long) }) }
/**
* Convenience fixture if your [TestTableOperationsClient] returns integers as [BigDecimal]
* rather than [Long].
*/
@JvmStatic
fun int64AsBigDecimal() =
int64.map { it.copy(outputValue = it.outputValue?.let { BigDecimal.valueOf(it as Long) }) }
@JvmStatic fun numeric38_0() = numeric38_0.toArgs()
}
object DataCoercionNumberFixtures {
val numeric38_9Max = bigdec("99999999999999999999999999999.999999999")
val numeric38_9Min = bigdec("-99999999999999999999999999999.999999999")
const val ZERO = "0"
const val ONE = "1"
const val NEGATIVE_ONE = "-1"
const val ONE_HUNDRED_TWENTY_THREE_POINT_FOUR = "123.4"
const val NEGATIVE_ONE_HUNDRED_TWENTY_THREE_POINT_FOUR = "-123.4"
const val POSITIVE_HIGH_PRECISION_FLOAT = "positive high-precision float"
const val NEGATIVE_HIGH_PRECISION_FLOAT = "negative high-precision float"
const val NUMERIC_38_9_MAX = "numeric(38,9) max"
const val NUMERIC_38_9_MIN = "numeric(38,9) min"
const val SMALLEST_POSITIVE_FLOAT32 = "smallest positive float32"
const val SMALLEST_NEGATIVE_FLOAT32 = "smallest negative float32"
const val LARGEST_POSITIVE_FLOAT32 = "largest positive float32"
const val LARGEST_NEGATIVE_FLOAT32 = "largest negative float32"
const val SMALLEST_POSITIVE_FLOAT64 = "smallest positive float64"
const val SMALLEST_NEGATIVE_FLOAT64 = "smallest negative float64"
const val LARGEST_POSITIVE_FLOAT64 = "largest positive float64"
const val LARGEST_NEGATIVE_FLOAT64 = "largest negative float64"
const val SLIGHTLY_ABOVE_LARGEST_POSITIVE_FLOAT64 = "slightly above largest positive float64"
const val SLIGHTLY_BELOW_LARGEST_NEGATIVE_FLOAT64 = "slightly below largest negative float64"
val float64 =
listOf(
case(NULL, NullValue, null),
case(ZERO, NumberValue(bigdec(0)), 0.0),
case(ONE, NumberValue(bigdec(1)), 1.0),
case(NEGATIVE_ONE, NumberValue(bigdec(-1)), -1.0),
// This value isn't exactly representable as a float64
// (the exact value is `123.400000000000005684341886080801486968994140625`)
// but we should preserve the canonical representation
case(ONE_HUNDRED_TWENTY_THREE_POINT_FOUR, NumberValue(bigdec("123.4")), 123.4),
case(
NEGATIVE_ONE_HUNDRED_TWENTY_THREE_POINT_FOUR,
NumberValue(bigdec("-123.4")),
-123.4
),
// These values have too much precision for a float64, so we round them
case(
POSITIVE_HIGH_PRECISION_FLOAT,
NumberValue(bigdec("1234567890.1234567890123456789")),
1234567890.1234567,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
NEGATIVE_HIGH_PRECISION_FLOAT,
NumberValue(bigdec("-1234567890.1234567890123456789")),
-1234567890.1234567,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
NUMERIC_38_9_MAX,
NumberValue(numeric38_9Max),
1.0E29,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
NUMERIC_38_9_MIN,
NumberValue(numeric38_9Min),
-1.0E29,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
// min/max_value are all positive values, so we need to manually test their negative
// version
case(
SMALLEST_POSITIVE_FLOAT32,
NumberValue(bigdec(Float.MIN_VALUE.toDouble())),
Float.MIN_VALUE.toDouble()
),
case(
SMALLEST_NEGATIVE_FLOAT32,
NumberValue(bigdec(-Float.MIN_VALUE.toDouble())),
-Float.MIN_VALUE.toDouble()
),
case(
LARGEST_POSITIVE_FLOAT32,
NumberValue(bigdec(Float.MAX_VALUE.toDouble())),
Float.MAX_VALUE.toDouble()
),
case(
LARGEST_NEGATIVE_FLOAT32,
NumberValue(bigdec(-Float.MAX_VALUE.toDouble())),
-Float.MAX_VALUE.toDouble()
),
case(
SMALLEST_POSITIVE_FLOAT64,
NumberValue(bigdec(Double.MIN_VALUE)),
Double.MIN_VALUE
),
case(
SMALLEST_NEGATIVE_FLOAT64,
NumberValue(bigdec(-Double.MIN_VALUE)),
-Double.MIN_VALUE
),
case(LARGEST_POSITIVE_FLOAT64, NumberValue(bigdec(Double.MAX_VALUE)), Double.MAX_VALUE),
case(
LARGEST_NEGATIVE_FLOAT64,
NumberValue(bigdec(-Double.MAX_VALUE)),
-Double.MAX_VALUE
),
// These values are out of bounds, so we null them
case(
SLIGHTLY_ABOVE_LARGEST_POSITIVE_FLOAT64,
NumberValue(bigdec(Double.MAX_VALUE) + bigdec(Double.MIN_VALUE)),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
SLIGHTLY_BELOW_LARGEST_NEGATIVE_FLOAT64,
NumberValue(bigdec(-Double.MAX_VALUE) - bigdec(Double.MIN_VALUE)),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
)
val numeric38_9 =
listOf(
case(NULL, NullValue, null),
case(ZERO, NumberValue(bigdec(0)), bigdec(0.0)),
case(ONE, NumberValue(bigdec(1)), bigdec(1.0)),
case(NEGATIVE_ONE, NumberValue(bigdec(-1)), bigdec(-1.0)),
// This value isn't exactly representable as a float64
// (the exact value is `123.400000000000005684341886080801486968994140625`)
// but it's perfectly fine as a numeric(38, 9)
case(
ONE_HUNDRED_TWENTY_THREE_POINT_FOUR,
NumberValue(bigdec("123.4")),
bigdec("123.4")
),
case(
NEGATIVE_ONE_HUNDRED_TWENTY_THREE_POINT_FOUR,
NumberValue(bigdec("-123.4")),
bigdec("-123.4")
),
// These values have too much precision for a numeric(38, 9), so we round them
case(
POSITIVE_HIGH_PRECISION_FLOAT,
NumberValue(bigdec("1234567890.1234567890123456789")),
bigdec("1234567890.123456789"),
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
NEGATIVE_HIGH_PRECISION_FLOAT,
NumberValue(bigdec("-1234567890.1234567890123456789")),
bigdec("-1234567890.123456789"),
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
SMALLEST_POSITIVE_FLOAT32,
NumberValue(bigdec(Float.MIN_VALUE.toDouble())),
bigdec(0),
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
SMALLEST_NEGATIVE_FLOAT32,
NumberValue(bigdec(-Float.MIN_VALUE.toDouble())),
bigdec(0),
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
SMALLEST_POSITIVE_FLOAT64,
NumberValue(bigdec(Double.MIN_VALUE)),
bigdec(0),
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
SMALLEST_NEGATIVE_FLOAT64,
NumberValue(bigdec(-Double.MIN_VALUE)),
bigdec(0),
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
// numeric bounds are perfectly fine
case(NUMERIC_38_9_MAX, NumberValue(numeric38_9Max), numeric38_9Max),
case(NUMERIC_38_9_MIN, NumberValue(numeric38_9Min), numeric38_9Min),
// These values are out of bounds, so we null them
case(
LARGEST_POSITIVE_FLOAT32,
NumberValue(bigdec(Float.MAX_VALUE.toDouble())),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
LARGEST_NEGATIVE_FLOAT32,
NumberValue(bigdec(-Float.MAX_VALUE.toDouble())),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
LARGEST_POSITIVE_FLOAT64,
NumberValue(bigdec(Double.MAX_VALUE)),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
LARGEST_NEGATIVE_FLOAT64,
NumberValue(bigdec(-Double.MAX_VALUE)),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
SLIGHTLY_ABOVE_LARGEST_POSITIVE_FLOAT64,
NumberValue(bigdec(Double.MAX_VALUE) + bigdec(Double.MIN_VALUE)),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
SLIGHTLY_BELOW_LARGEST_NEGATIVE_FLOAT64,
NumberValue(bigdec(-Double.MAX_VALUE) - bigdec(Double.MIN_VALUE)),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
)
.map { it.copy(outputValue = (it.outputValue as BigDecimal?)?.setScale(9)) }
@JvmStatic fun float64() = float64.toArgs()
@JvmStatic fun numeric38_9() = numeric38_9.toArgs()
}
const val SIMPLE_TIMESTAMP = "simple timestamp"
const val UNIX_EPOCH = "unix epoch"
const val MINIMUM_TIMESTAMP = "minimum timestamp"
const val MAXIMUM_TIMESTAMP = "maximum timestamp"
const val OUT_OF_RANGE_TIMESTAMP = "out of range timestamp"
const val HIGH_PRECISION_TIMESTAMP = "high-precision timestamp"
object DataCoercionTimestampTzFixtures {
/**
* Many warehouses support timestamps between years 0001 - 9999.
*
* Depending on the exact warehouse, you may need to tweak the precision on some values. For
* example, Snowflake supports nanosecond-precision timestamps (9 decimal places), but Bigquery
* only supports microsecond-precision (6 decimal places). Bigquery would probably do something
* like:
* ```kotlin
* DataCoercionTimestampTzFixtures.commonWarehouse
*     .map {
*         when (it.name) {
*             "maximum timestamp" -> it.copy(
*                 inputValue = TimestampWithTimezoneValue("9999-12-31T23:59:59.999999Z"),
*                 outputValue = OffsetDateTime.parse("9999-12-31T23:59:59.999999Z"),
*                 changeReason = Reason.DESTINATION_FIELD_SIZE_LIMITATION,
*             )
*             "high-precision timestamp" -> it.copy(
*                 outputValue = OffsetDateTime.parse("2025-01-23T01:01:00.123456Z"),
*                 changeReason = Reason.DESTINATION_FIELD_SIZE_LIMITATION,
*             )
*             else -> it
*         }
*     }
* ```
*/
val commonWarehouse =
listOf(
case(NULL, NullValue, null),
case(
SIMPLE_TIMESTAMP,
TimestampWithTimezoneValue("2025-01-23T12:34:56.789Z"),
"2025-01-23T12:34:56.789Z",
),
case(
UNIX_EPOCH,
TimestampWithTimezoneValue("1970-01-01T00:00:00Z"),
"1970-01-01T00:00:00Z",
),
case(
MINIMUM_TIMESTAMP,
TimestampWithTimezoneValue("0001-01-01T00:00:00Z"),
"0001-01-01T00:00:00Z",
),
case(
MAXIMUM_TIMESTAMP,
TimestampWithTimezoneValue("9999-12-31T23:59:59.999999999Z"),
"9999-12-31T23:59:59.999999999Z",
),
case(
OUT_OF_RANGE_TIMESTAMP,
TimestampWithTimezoneValue(odt("10000-01-01T00:00Z")),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION,
),
case(
HIGH_PRECISION_TIMESTAMP,
TimestampWithTimezoneValue("2025-01-23T01:01:00.123456789Z"),
"2025-01-23T01:01:00.123456789Z",
),
)
@JvmStatic fun commonWarehouse() = commonWarehouse.toArgs()
}
object DataCoercionTimestampNtzFixtures {
/** See [DataCoercionTimestampTzFixtures.commonWarehouse] for explanation */
val commonWarehouse =
listOf(
case(NULL, NullValue, null),
case(
SIMPLE_TIMESTAMP,
TimestampWithoutTimezoneValue("2025-01-23T12:34:56.789"),
"2025-01-23T12:34:56.789",
),
case(
UNIX_EPOCH,
TimestampWithoutTimezoneValue("1970-01-01T00:00:00"),
"1970-01-01T00:00:00",
),
case(
MINIMUM_TIMESTAMP,
TimestampWithoutTimezoneValue("0001-01-01T00:00:00"),
"0001-01-01T00:00:00",
),
case(
MAXIMUM_TIMESTAMP,
TimestampWithoutTimezoneValue("9999-12-31T23:59:59.999999999"),
"9999-12-31T23:59:59.999999999",
),
case(
OUT_OF_RANGE_TIMESTAMP,
TimestampWithoutTimezoneValue(ldt("10000-01-01T00:00")),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION,
),
case(
HIGH_PRECISION_TIMESTAMP,
TimestampWithoutTimezoneValue("2025-01-23T01:01:00.123456789"),
"2025-01-23T01:01:00.123456789",
),
)
@JvmStatic fun commonWarehouse() = commonWarehouse.toArgs()
}
const val MIDNIGHT = "midnight"
const val MAX_TIME = "max time"
const val HIGH_NOON = "high noon"
object DataCoercionTimeTzFixtures {
val timetz =
listOf(
case(NULL, NullValue, null),
case(MIDNIGHT, TimeWithTimezoneValue("00:00Z"), "00:00Z"),
case(MAX_TIME, TimeWithTimezoneValue("23:59:59.999999999Z"), "23:59:59.999999999Z"),
case(HIGH_NOON, TimeWithTimezoneValue("12:00Z"), "12:00Z"),
)
@JvmStatic fun timetz() = timetz.toArgs()
}
object DataCoercionTimeNtzFixtures {
val timentz =
listOf(
case(NULL, NullValue, null),
case(MIDNIGHT, TimeWithoutTimezoneValue("00:00"), "00:00"),
case(MAX_TIME, TimeWithoutTimezoneValue("23:59:59.999999999"), "23:59:59.999999999"),
case(HIGH_NOON, TimeWithoutTimezoneValue("12:00"), "12:00"),
)
@JvmStatic fun timentz() = timentz.toArgs()
}
object DataCoercionDateFixtures {
val commonWarehouse =
listOf(
case(NULL, NullValue, null),
case(
SIMPLE_TIMESTAMP,
DateValue("2025-01-23"),
"2025-01-23",
),
case(
UNIX_EPOCH,
DateValue("1970-01-01"),
"1970-01-01",
),
case(
MINIMUM_TIMESTAMP,
DateValue("0001-01-01"),
"0001-01-01",
),
case(
MAXIMUM_TIMESTAMP,
DateValue("9999-12-31"),
"9999-12-31",
),
case(
OUT_OF_RANGE_TIMESTAMP,
DateValue(date("10000-01-01")),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION,
),
)
@JvmStatic fun commonWarehouse() = commonWarehouse.toArgs()
}
object DataCoercionStringFixtures {
const val EMPTY_STRING = "empty string"
const val SHORT_STRING = "short string"
const val LONG_STRING = "long string"
const val SPECIAL_CHARS_STRING = "special chars string"
val strings =
listOf(
case(NULL, NullValue, null),
case(EMPTY_STRING, StringValue(""), ""),
case(SHORT_STRING, StringValue("foo"), "foo"),
// Implementers may override this to test their destination-specific limits.
// The default value is 16MB + 1 byte (slightly longer than snowflake's varchar limit).
case(
LONG_STRING,
StringValue("a".repeat(16777216 + 1)),
null,
Reason.DESTINATION_FIELD_SIZE_LIMITATION
),
case(
SPECIAL_CHARS_STRING,
StringValue("`~!@#$%^&*()-=_+[]\\{}|o'O\",./<>?)Δ⅀↑∀"),
"`~!@#$%^&*()-=_+[]\\{}|o'O\",./<>?)Δ⅀↑∀"
),
)
@JvmStatic fun strings() = strings.toArgs()
}
object DataCoercionObjectFixtures {
const val EMPTY_OBJECT = "empty object"
const val NORMAL_OBJECT = "normal object"
val objects =
listOf(
case(NULL, NullValue, null),
case(EMPTY_OBJECT, ObjectValue(linkedMapOf()), emptyMap<String, Any?>()),
case(
NORMAL_OBJECT,
ObjectValue(linkedMapOf("foo" to StringValue("bar"))),
mapOf("foo" to "bar")
),
)
val stringifiedObjects =
objects.map { fixture ->
fixture.copy(outputValue = fixture.outputValue?.serializeToString())
}
@JvmStatic fun objects() = objects.toArgs()
@JvmStatic fun stringifiedObjects() = stringifiedObjects.toArgs()
}
object DataCoercionArrayFixtures {
const val EMPTY_ARRAY = "empty array"
const val NORMAL_ARRAY = "normal array"
val arrays =
listOf(
case(NULL, NullValue, null),
case(EMPTY_ARRAY, ArrayValue(emptyList()), emptyList<Any?>()),
case(NORMAL_ARRAY, ArrayValue(listOf(StringValue("foo"))), listOf("foo")),
)
val stringifiedArrays =
arrays.map { fixture ->
fixture.copy(outputValue = fixture.outputValue?.serializeToString())
}
@JvmStatic fun arrays() = arrays.toArgs()
@JvmStatic fun stringifiedArrays() = stringifiedArrays.toArgs()
}
const val UNION_INT_VALUE = "int value"
const val UNION_OBJ_VALUE = "object value"
const val UNION_STR_VALUE = "string value"
object DataCoercionUnionFixtures {
val unions =
listOf(
case(NULL, NullValue, null),
case(UNION_INT_VALUE, IntegerValue(42), 42L),
case(UNION_STR_VALUE, StringValue("foo"), "foo"),
case(
UNION_OBJ_VALUE,
ObjectValue(linkedMapOf("foo" to StringValue("bar"))),
mapOf("foo" to "bar")
),
)
val stringifiedUnions =
unions.map { fixture ->
fixture.copy(outputValue = fixture.outputValue?.serializeToString())
}
@JvmStatic fun unions() = unions.toArgs()
@JvmStatic fun stringifiedUnions() = stringifiedUnions.toArgs()
}
object DataCoercionLegacyUnionFixtures {
val unions =
listOf(
case(NULL, NullValue, null),
// Legacy union of int x object will select object, and you can't write an int to an
// object column.
// So we should null it out.
case(UNION_INT_VALUE, IntegerValue(42), null, Reason.DESTINATION_TYPECAST_ERROR),
// Similarly, we should null out strings.
case(UNION_STR_VALUE, StringValue("foo"), "foo"),
// But objects can be written as objects, so retain this value.
case(
UNION_OBJ_VALUE,
ObjectValue(linkedMapOf("foo" to StringValue("bar"))),
mapOf("foo" to "bar")
),
)
val stringifiedUnions =
DataCoercionUnionFixtures.unions.map { fixture ->
fixture.copy(outputValue = fixture.outputValue?.serializeToString())
}
@JvmStatic fun unions() = unions.toArgs()
@JvmStatic fun stringifiedUnions() = DataCoercionUnionFixtures.stringifiedUnions.toArgs()
}
// This is pretty much identical to UnionFixtures, but separating them in case we need to add
// different test cases for either of them.
object DataCoercionUnknownFixtures {
const val INT_VALUE = "integer value"
const val STR_VALUE = "string value"
const val OBJ_VALUE = "object value"
val unknowns =
listOf(
case(NULL, NullValue, null),
case(INT_VALUE, IntegerValue(42), 42L),
case(STR_VALUE, StringValue("foo"), "foo"),
case(
OBJ_VALUE,
ObjectValue(linkedMapOf("foo" to StringValue("bar"))),
mapOf("foo" to "bar")
),
)
val stringifiedUnknowns =
unknowns.map { fixture ->
fixture.copy(outputValue = fixture.outputValue?.serializeToString())
}
@JvmStatic fun unknowns() = unknowns.toArgs()
@JvmStatic fun stringifiedUnknowns() = stringifiedUnknowns.toArgs()
}
fun List<DataCoercionTestCase>.toArgs(): List<Arguments> =
this.map { Arguments.argumentSet(it.name, it.inputValue, it.outputValue, it.changeReason) }
.toList()
/**
* Utility method to use the BigDecimal constructor (supports exponential notation like `1e38`) to
* construct a BigInteger.
*/
fun bigint(str: String): BigInteger = BigDecimal(str).toBigIntegerExact()
/** Shorthand utility method to construct a bigint from a long */
fun bigint(long: Long): BigInteger = BigInteger.valueOf(long)
fun bigdec(str: String): BigDecimal = BigDecimal(str)
fun bigdec(double: Double): BigDecimal = BigDecimal.valueOf(double)
fun bigdec(int: Int): BigDecimal = BigDecimal.valueOf(int.toDouble())
fun odt(str: String): OffsetDateTime = OffsetDateTime.parse(str, dateTimeFormatter)
fun ldt(str: String): LocalDateTime = LocalDateTime.parse(str, dateTimeFormatter)
fun date(str: String): LocalDate = LocalDate.parse(str, dateFormatter)
// The default java.time.*.parse() behavior only accepts up to 4-digit years.
// Build a custom formatter to handle larger years.
val dateFormatter =
DateTimeFormatterBuilder()
// java.time.* supports up to 9-digit years
.appendValue(ChronoField.YEAR, 1, 9, SignStyle.NORMAL)
.appendLiteral('-')
.appendValue(ChronoField.MONTH_OF_YEAR)
.appendLiteral('-')
.appendValue(ChronoField.DAY_OF_MONTH)
.toFormatter()
val dateTimeFormatter =
DateTimeFormatterBuilder()
.append(dateFormatter)
.appendLiteral('T')
// Accepts strings with/without an offset, so we can use this formatter
// for both timestamp with and without timezone
.append(DateTimeFormatter.ISO_TIME)
.toFormatter()
/**
* Represents a single data coercion test case. You probably want to use [case] as a shorthand
* constructor.
*
* @param name A short human-readable name for the test. Primarily useful for tests where
* [inputValue] is either very long, or otherwise hard to read.
* @param inputValue The value to pass into [ValueCoercer.validate]
* @param outputValue The value that we expect to read back from the destination. Should be
* basically equivalent to the output of [ValueCoercer.validate]
* @param changeReason If `validate` returns Truncate/Nullify, the reason for that
* truncation/nullification. If `validate` returns Valid, this should be null.
*/
data class DataCoercionTestCase(
val name: String,
val inputValue: AirbyteValue,
val outputValue: Any?,
val changeReason: Reason? = null,
)
fun case(
name: String,
inputValue: AirbyteValue,
outputValue: Any?,
changeReason: Reason? = null,
) = DataCoercionTestCase(name, inputValue, outputValue, changeReason)
const val NULL = "null"
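
As the header comment in this file suggests, a destination can derive its own fixtures by referencing an existing `val`, mapping it, and exposing the result through `toArgs()`. A minimal sketch of that pattern, assuming a hypothetical destination with a 1MB varchar limit (the object name and the limit are illustrative, not taken from this commit):

// Hypothetical destination-specific fixture: reuse DataCoercionStringFixtures.strings,
// but override the LONG_STRING case to match an assumed 1MB varchar limit, and expose
// the result as a JUnit MethodSource via toArgs().
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Reason

object MyDestinationStringFixtures {
    val strings =
        DataCoercionStringFixtures.strings.map {
            when (it.name) {
                DataCoercionStringFixtures.LONG_STRING ->
                    it.copy(
                        // One byte past the assumed 1MB limit, so the value should be nulled.
                        inputValue = StringValue("a".repeat(1_048_576 + 1)),
                        outputValue = null,
                        changeReason = Reason.DESTINATION_FIELD_SIZE_LIMITATION,
                    )
                else -> it
            }
        }

    @JvmStatic fun strings() = strings.toArgs()
}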

View File

@@ -0,0 +1,369 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.component
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.ArrayType
import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
import io.airbyte.cdk.load.data.BooleanType
import io.airbyte.cdk.load.data.BooleanValue
import io.airbyte.cdk.load.data.DateType
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerType
import io.airbyte.cdk.load.data.NumberType
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.TimeTypeWithTimezone
import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
import io.airbyte.cdk.load.data.UnionType
import io.airbyte.cdk.load.data.UnknownType
import io.airbyte.cdk.load.dataflow.transform.ValueCoercer
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.util.Jsons
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Reason
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import kotlinx.coroutines.test.runTest
/**
* The tests in this class are designed to reference the parameters defined in
* `DataCoercionFixtures.kt`. For example, you might annotate [`handle integer values`] with
* `@MethodSource("io.airbyte.cdk.load.component.DataCoercionIntegerFixtures#int32")`. See each
* fixture class for explanations of what behavior they are exercising.
*
* Note that this class _only_ exercises [ValueCoercer.validate]. You should write separate unit
* tests for [ValueCoercer.map]. For now, the `map` function is primarily intended for transforming
* `UnionType` fields into other types (typically `StringType`), at which point your `validate`
* implementation should be able to handle any StringValue (regardless of whether it was originally
* a StringType or UnionType).
*/
@MicronautTest(environments = ["component"], resolveParameters = false)
interface DataCoercionSuite {
val coercer: ValueCoercer
val airbyteMetaColumnMapping: Map<String, String>
get() = Meta.COLUMN_NAMES.associateWith { it }
val columnNameMapping: ColumnNameMapping
get() = ColumnNameMapping(mapOf("test" to "test"))
val opsClient: TableOperationsClient
val testClient: TestTableOperationsClient
val schemaFactory: TableSchemaFactory
val harness: TableOperationsTestHarness
get() =
TableOperationsTestHarness(
opsClient,
testClient,
schemaFactory,
airbyteMetaColumnMapping
)
/** Fixtures are defined in [DataCoercionIntegerFixtures]. */
fun `handle integer values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(IntegerType, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionNumberFixtures]. */
fun `handle number values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(NumberType, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionTimestampTzFixtures]. */
fun `handle timestamptz values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(TimestampTypeWithTimezone, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionTimestampNtzFixtures]. */
fun `handle timestampntz values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(TimestampTypeWithoutTimezone, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionTimeTzFixtures]. */
fun `handle timetz values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(TimeTypeWithTimezone, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionTimeNtzFixtures]. */
fun `handle timentz values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(TimeTypeWithoutTimezone, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionDateFixtures]. */
fun `handle date values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(DateType, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** No fixtures, hardcoded to just write `true` */
fun `handle bool values`(expectedValue: Any?) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(BooleanType, nullable = true),
// Just test on `true` and assume `false` also works
BooleanValue(true),
expectedValue,
// If your destination is nulling/truncating booleans... that's almost definitely a bug
expectedChangeReason = null,
)
}
/** Fixtures are defined in [DataCoercionStringFixtures]. */
fun `handle string values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(StringType, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionObjectFixtures]. */
fun `handle object values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(
ObjectType(linkedMapOf("foo" to FieldType(StringType, true))),
nullable = true
),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionObjectFixtures]. */
fun `handle empty object values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(ObjectTypeWithEmptySchema, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionObjectFixtures]. */
fun `handle schemaless object values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(ObjectTypeWithoutSchema, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionArrayFixtures]. */
fun `handle array values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(ArrayType(FieldType(StringType, true)), nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/** Fixtures are defined in [DataCoercionArrayFixtures]. */
fun `handle schemaless array values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(ArrayTypeWithoutSchema, nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/**
* All destinations should implement this, even if your destination is supporting legacy unions.
*
* Fixtures are defined in [DataCoercionUnionFixtures].
*/
fun `handle union values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(
UnionType(
setOf(
ObjectType(linkedMapOf("foo" to FieldType(StringType, true))),
IntegerType,
StringType,
),
isLegacyUnion = false
),
nullable = true
),
inputValue,
expectedValue,
expectedChangeReason,
)
}
/**
* Only legacy destinations that are maintaining "legacy" union behavior should implement this
 * test. If you're not sure, check whether your `application-connector.yaml` includes an
* `airbyte.destination.core.types.unions: LEGACY` property.
*
* Fixtures are defined in [DataCoercionLegacyUnionFixtures].
*/
fun `handle legacy union values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(
UnionType(
setOf(
ObjectType(linkedMapOf("foo" to FieldType(StringType, true))),
IntegerType,
StringType,
),
isLegacyUnion = true
),
nullable = true
),
inputValue,
expectedValue,
expectedChangeReason,
)
}
fun `handle unknown values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) = runTest {
harness.testValueCoercion(
coercer,
columnNameMapping,
FieldType(UnknownType(Jsons.readTree(("""{"type": "potato"}"""))), nullable = true),
inputValue,
expectedValue,
expectedChangeReason,
)
}
}
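For reference, a destination wires these suite methods up by overriding them with JUnit's parameterized-test annotations; the ClickHouse data-coercion test further down in this diff follows exactly this pattern. Below is a minimal sketch under stated assumptions: the class name `ExampleDataCoercionTest` and the `#default` fixture method are invented for illustration, and the real fixture method names live in the `DataCoercion*Fixtures` objects.

// Sketch only: class name and the "#default" fixture method below are illustrative
// assumptions; substitute the real fixture method from DataCoercionLegacyUnionFixtures.
import io.airbyte.cdk.load.component.DataCoercionSuite
import io.airbyte.cdk.load.component.TableOperationsClient
import io.airbyte.cdk.load.component.TestTableOperationsClient
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.dataflow.transform.ValueCoercer
import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Reason
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.MethodSource

@MicronautTest(environments = ["component"], resolveParameters = false)
class ExampleDataCoercionTest(
    override val coercer: ValueCoercer,
    override val opsClient: TableOperationsClient,
    override val testClient: TestTableOperationsClient,
    override val schemaFactory: TableSchemaFactory,
) : DataCoercionSuite {
    // Only wire this up if your application-connector.yaml sets
    // airbyte.destination.core.types.unions: LEGACY (see the doc comment above).
    @ParameterizedTest
    @MethodSource("io.airbyte.cdk.load.component.DataCoercionLegacyUnionFixtures#default")
    override fun `handle legacy union values`(
        inputValue: AirbyteValue,
        expectedValue: Any?,
        expectedChangeReason: Reason?
    ) {
        super.`handle legacy union values`(inputValue, expectedValue, expectedChangeReason)
    }
}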

View File

@@ -23,6 +23,7 @@ import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithTimezone import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
import io.airbyte.cdk.load.data.TimestampWithTimezoneValue import io.airbyte.cdk.load.data.TimestampWithTimezoneValue
import io.airbyte.cdk.load.data.UnionType
import io.airbyte.cdk.load.data.UnknownType import io.airbyte.cdk.load.data.UnknownType
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
@@ -84,6 +85,18 @@ object TableOperationsFixtures {
"array" to FieldType(ArrayType(FieldType(StringType, true)), true), "array" to FieldType(ArrayType(FieldType(StringType, true)), true),
"object" to "object" to
FieldType(ObjectType(linkedMapOf("key" to FieldType(StringType, true))), true), FieldType(ObjectType(linkedMapOf("key" to FieldType(StringType, true))), true),
"union" to
FieldType(
UnionType(setOf(StringType, IntegerType), isLegacyUnion = false),
true
),
// Most destinations just ignore the isLegacyUnion flag, which is totally fine.
// This is here for the small set of connectors that respect it.
"legacy_union" to
FieldType(
UnionType(setOf(StringType, IntegerType), isLegacyUnion = true),
true
),
"unknown" to FieldType(UnknownType(Jsons.readTree("""{"type": "potato"}""")), true), "unknown" to FieldType(UnknownType(Jsons.readTree("""{"type": "potato"}""")), true),
), ),
) )
@@ -101,6 +114,8 @@ object TableOperationsFixtures {
"time_ntz" to "time_ntz", "time_ntz" to "time_ntz",
"array" to "array", "array" to "array",
"object" to "object", "object" to "object",
"union" to "union",
"legacy_union" to "legacy_union",
"unknown" to "unknown", "unknown" to "unknown",
) )
) )
@@ -714,6 +729,11 @@ object TableOperationsFixtures {
return map { record -> record.mapKeys { (k, _) -> totalMapping.invert()[k] ?: k } } return map { record -> record.mapKeys { (k, _) -> totalMapping.invert()[k] ?: k } }
} }
fun <V> List<Map<String, V>>.removeAirbyteColumns(
airbyteMetaColumnMapping: Map<String, String>
): List<Map<String, V>> =
this.map { rec -> rec.filter { !airbyteMetaColumnMapping.containsValue(it.key) } }
fun <V> List<Map<String, V>>.removeNulls() = fun <V> List<Map<String, V>>.removeNulls() =
this.map { record -> record.filterValues { it != null } } this.map { record -> record.filterValues { it != null } }

View File

@@ -58,7 +58,8 @@ interface TableOperationsSuite {
get() = Meta.COLUMN_NAMES.associateWith { it } get() = Meta.COLUMN_NAMES.associateWith { it }
private val harness: TableOperationsTestHarness private val harness: TableOperationsTestHarness
get() = TableOperationsTestHarness(client, testClient, airbyteMetaColumnMapping) get() =
TableOperationsTestHarness(client, testClient, schemaFactory, airbyteMetaColumnMapping)
/** Tests basic database connectivity by pinging the database. */ /** Tests basic database connectivity by pinging the database. */
fun `connect to database`() = runTest { assertDoesNotThrow { testClient.ping() } } fun `connect to database`() = runTest { assertDoesNotThrow { testClient.ping() } }
@@ -606,7 +607,7 @@ interface TableOperationsSuite {
val targetTableSchema = val targetTableSchema =
schemaFactory.make( schemaFactory.make(
targetTable, targetTable,
Fixtures.TEST_INTEGER_SCHEMA.properties, Fixtures.ID_TEST_WITH_CDC_SCHEMA.properties,
Dedupe( Dedupe(
primaryKey = listOf(listOf(Fixtures.ID_FIELD)), primaryKey = listOf(listOf(Fixtures.ID_FIELD)),
cursor = listOf(Fixtures.TEST_FIELD), cursor = listOf(Fixtures.TEST_FIELD),

View File

@@ -4,11 +4,24 @@
package io.airbyte.cdk.load.component package io.airbyte.cdk.load.component
import io.airbyte.cdk.load.command.Append
import io.airbyte.cdk.load.command.DestinationStream import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.component.TableOperationsFixtures.inputRecord
import io.airbyte.cdk.load.component.TableOperationsFixtures.insertRecords import io.airbyte.cdk.load.component.TableOperationsFixtures.insertRecords
import io.airbyte.cdk.load.component.TableOperationsFixtures.removeAirbyteColumns
import io.airbyte.cdk.load.component.TableOperationsFixtures.removeNulls
import io.airbyte.cdk.load.component.TableOperationsFixtures.reverseColumnNameMapping
import io.airbyte.cdk.load.data.AirbyteValue import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.EnrichedAirbyteValue
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.dataflow.transform.ValidationResult
import io.airbyte.cdk.load.dataflow.transform.ValueCoercer
import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.cdk.load.schema.model.TableName import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.ColumnNameMapping import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Reason
import io.github.oshai.kotlinlogging.KotlinLogging import io.github.oshai.kotlinlogging.KotlinLogging
import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertEquals
@@ -21,6 +34,7 @@ private val log = KotlinLogging.logger {}
class TableOperationsTestHarness( class TableOperationsTestHarness(
private val client: TableOperationsClient, private val client: TableOperationsClient,
private val testClient: TestTableOperationsClient, private val testClient: TestTableOperationsClient,
private val schemaFactory: TableSchemaFactory,
private val airbyteMetaColumnMapping: Map<String, String>, private val airbyteMetaColumnMapping: Map<String, String>,
) { ) {
@@ -100,8 +114,77 @@ class TableOperationsTestHarness(
/** Reads records from a table, filtering out Meta columns. */ /** Reads records from a table, filtering out Meta columns. */
suspend fun readTableWithoutMetaColumns(tableName: TableName): List<Map<String, Any>> { suspend fun readTableWithoutMetaColumns(tableName: TableName): List<Map<String, Any>> {
val tableRead = testClient.readTable(tableName) val tableRead = testClient.readTable(tableName)
return tableRead.map { rec -> return tableRead.removeAirbyteColumns(airbyteMetaColumnMapping)
rec.filter { !airbyteMetaColumnMapping.containsValue(it.key) } }
/** Apply the coercer to a value and verify that we can write the coerced value correctly */
suspend fun testValueCoercion(
coercer: ValueCoercer,
columnNameMapping: ColumnNameMapping,
fieldType: FieldType,
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?,
) {
val testNamespace = TableOperationsFixtures.generateTestNamespace("test")
val tableName =
TableOperationsFixtures.generateTestTableName("table-test-table", testNamespace)
val schema = ObjectType(linkedMapOf("test" to fieldType))
val tableSchema = schemaFactory.make(tableName, schema.properties, Append)
val stream =
TableOperationsFixtures.createStream(
namespace = tableName.namespace,
name = tableName.name,
tableSchema = tableSchema,
)
val inputValueAsEnrichedAirbyteValue =
EnrichedAirbyteValue(
inputValue,
fieldType.type,
"test",
airbyteMetaField = null,
)
val validatedValue = coercer.validate(inputValueAsEnrichedAirbyteValue)
val valueToInsert: AirbyteValue
val changeReason: Reason?
when (validatedValue) {
is ValidationResult.ShouldNullify -> {
valueToInsert = NullValue
changeReason = validatedValue.reason
}
is ValidationResult.ShouldTruncate -> {
valueToInsert = validatedValue.truncatedValue
changeReason = validatedValue.reason
}
ValidationResult.Valid -> {
valueToInsert = inputValue
changeReason = null
}
} }
client.createNamespace(testNamespace)
client.createTable(stream, tableName, columnNameMapping, replace = false)
testClient.insertRecords(
tableName,
columnNameMapping,
inputRecord("test" to valueToInsert),
)
val actualRecords =
testClient
.readTable(tableName)
.removeAirbyteColumns(airbyteMetaColumnMapping)
.reverseColumnNameMapping(columnNameMapping, airbyteMetaColumnMapping)
.removeNulls()
val actualValue = actualRecords.first()["test"]
assertEquals(
expectedValue,
actualValue,
"For input $inputValue, expected ${expectedValue.simpleClassName()}; actual value was ${actualValue.simpleClassName()}. Coercer output was $validatedValue.",
)
assertEquals(expectedChangeReason, changeReason)
} }
} }
fun Any?.simpleClassName() = this?.let { it::class.simpleName } ?: "null"

View File

@@ -44,7 +44,13 @@ interface TableSchemaEvolutionSuite {
val schemaFactory: TableSchemaFactory val schemaFactory: TableSchemaFactory
private val harness: TableOperationsTestHarness private val harness: TableOperationsTestHarness
get() = TableOperationsTestHarness(opsClient, testClient, airbyteMetaColumnMapping) get() =
TableOperationsTestHarness(
opsClient,
testClient,
schemaFactory,
airbyteMetaColumnMapping
)
/** /**
* Test that the connector can correctly discover all of its own data types. This test creates a * Test that the connector can correctly discover all of its own data types. This test creates a

View File

@@ -1 +1 @@
version=0.1.87 version=0.1.91

View File

@@ -10,5 +10,6 @@ CONNECTOR_PATH_PREFIXES = {
"airbyte-integrations/connectors", "airbyte-integrations/connectors",
"docs/integrations/sources", "docs/integrations/sources",
"docs/integrations/destinations", "docs/integrations/destinations",
"docs/ai-agents/connectors",
} }
MERGE_METHOD = "squash" MERGE_METHOD = "squash"

View File

@@ -75,7 +75,7 @@ This will copy the specified connector version to your development bucket. This
_💡 Note: A prerequisite is you have [gsutil](https://cloud.google.com/storage/docs/gsutil) installed and have run `gsutil auth login`_ _💡 Note: A prerequisite is you have [gsutil](https://cloud.google.com/storage/docs/gsutil) installed and have run `gsutil auth login`_
```bash ```bash
TARGET_BUCKET=<YOUR-DEV_BUCKET> CONNECTOR="airbyte/source-stripe" VERSION="3.17.0-dev.ea013c8741" poetry run poe copy-connector-from-prod TARGET_BUCKET=<YOUR-DEV_BUCKET> CONNECTOR="airbyte/source-stripe" VERSION="3.17.0-preview.ea013c8" poetry run poe copy-connector-from-prod
``` ```
### Promote Connector Version to Latest ### Promote Connector Version to Latest
@@ -87,5 +87,5 @@ _💡 Note: A prerequisite is you have [gsutil](https://cloud.google.com/storage
_⚠️ Warning: It's important to know that this will remove ANY existing files in the latest folder that are not in the versioned folder as it calls `gsutil rsync` with `-d` enabled._ _⚠️ Warning: It's important to know that this will remove ANY existing files in the latest folder that are not in the versioned folder as it calls `gsutil rsync` with `-d` enabled._
```bash ```bash
TARGET_BUCKET=<YOUR-DEV_BUCKET> CONNECTOR="airbyte/source-stripe" VERSION="3.17.0-dev.ea013c8741" poetry run poe promote-connector-to-latest TARGET_BUCKET=<YOUR-DEV_BUCKET> CONNECTOR="airbyte/source-stripe" VERSION="3.17.0-preview.ea013c8" poetry run poe promote-connector-to-latest
``` ```

View File

@@ -28,8 +28,8 @@ def get_docker_hub_auth_token() -> str:
def get_docker_hub_headers() -> Dict | None: def get_docker_hub_headers() -> Dict | None:
if "DOCKER_HUB_USERNAME" not in os.environ or "DOCKER_HUB_PASSWORD" not in os.environ: if not os.environ.get("DOCKER_HUB_USERNAME") or not os.environ.get("DOCKER_HUB_PASSWORD"):
# If the Docker Hub credentials are not provided, we can only anonymously call the Docker Hub API. # If the Docker Hub credentials are not provided (or are empty), we can only anonymously call the Docker Hub API.
# This will only work for public images and lead to a lower rate limit. # This will only work for public images and lead to a lower rate limit.
return {} return {}
else: else:

View File

@@ -434,7 +434,7 @@ def generate_and_persist_registry_entry(
bucket_name (str): The name of the GCS bucket. bucket_name (str): The name of the GCS bucket.
repo_metadata_file_path (pathlib.Path): The path to the spec file. repo_metadata_file_path (pathlib.Path): The path to the spec file.
registry_type (str): The registry type. registry_type (str): The registry type.
docker_image_tag (str): The docker image tag associated with this release. Typically a semver string (e.g. '1.2.3'), possibly with a suffix (e.g. '1.2.3-dev.abcde12345') docker_image_tag (str): The docker image tag associated with this release. Typically a semver string (e.g. '1.2.3'), possibly with a suffix (e.g. '1.2.3-preview.abcde12')
is_prerelease (bool): Whether this is a prerelease, or a main release. is_prerelease (bool): Whether this is a prerelease, or a main release.
""" """
# Read the repo metadata dict to bootstrap ourselves. We need the docker repository, # Read the repo metadata dict to bootstrap ourselves. We need the docker repository,
@@ -444,7 +444,7 @@ def generate_and_persist_registry_entry(
try: try:
# Now that we have the docker repo, read the appropriate versioned metadata from GCS. # Now that we have the docker repo, read the appropriate versioned metadata from GCS.
# This metadata will differ in a few fields (e.g. in prerelease mode, dockerImageTag will contain the actual prerelease tag `1.2.3-dev.abcde12345`), # This metadata will differ in a few fields (e.g. in prerelease mode, dockerImageTag will contain the actual prerelease tag `1.2.3-preview.abcde12`),
# so we'll treat this as the source of truth (ish. See below for how we handle the registryOverrides field.) # so we'll treat this as the source of truth (ish. See below for how we handle the registryOverrides field.)
gcs_client = get_gcs_storage_client(gcs_creds=os.environ.get("GCS_CREDENTIALS")) gcs_client = get_gcs_storage_client(gcs_creds=os.environ.get("GCS_CREDENTIALS"))
bucket = gcs_client.bucket(bucket_name) bucket = gcs_client.bucket(bucket_name)
@@ -533,7 +533,9 @@ def generate_and_persist_registry_entry(
# For latest versions that are disabled, delete any existing registry entry to remove it from the registry # For latest versions that are disabled, delete any existing registry entry to remove it from the registry
if ( if (
"-rc" not in metadata_dict["data"]["dockerImageTag"] and "-dev" not in metadata_dict["data"]["dockerImageTag"] "-rc" not in metadata_dict["data"]["dockerImageTag"]
and "-dev" not in metadata_dict["data"]["dockerImageTag"]
and "-preview" not in metadata_dict["data"]["dockerImageTag"]
) and not metadata_dict["data"]["registryOverrides"][registry_type]["enabled"]: ) and not metadata_dict["data"]["registryOverrides"][registry_type]["enabled"]:
logger.info( logger.info(
f"{registry_type} is not enabled: deleting existing {registry_type} registry entry for {metadata_dict['data']['dockerRepository']} at latest path." f"{registry_type} is not enabled: deleting existing {registry_type} registry entry for {metadata_dict['data']['dockerRepository']} at latest path."

View File

@@ -5,7 +5,7 @@ data:
connectorType: source connectorType: source
dockerRepository: airbyte/image-exists-1 dockerRepository: airbyte/image-exists-1
githubIssueLabel: source-alloydb-strict-encrypt githubIssueLabel: source-alloydb-strict-encrypt
dockerImageTag: 2.0.0-dev.cf3628ccf3 dockerImageTag: 2.0.0-preview.cf3628c
documentationUrl: https://docs.airbyte.com/integrations/sources/existingsource documentationUrl: https://docs.airbyte.com/integrations/sources/existingsource
connectorSubtype: database connectorSubtype: database
releaseStage: generally_available releaseStage: generally_available

View File

@@ -231,7 +231,7 @@ def test_upload_prerelease(mocker, valid_metadata_yaml_files, tmp_path):
mocker.patch.object(commands.click, "secho") mocker.patch.object(commands.click, "secho")
mocker.patch.object(commands, "upload_metadata_to_gcs") mocker.patch.object(commands, "upload_metadata_to_gcs")
prerelease_tag = "0.3.0-dev.6d33165120" prerelease_tag = "0.3.0-preview.6d33165"
bucket = "my-bucket" bucket = "my-bucket"
metadata_file_path = valid_metadata_yaml_files[0] metadata_file_path = valid_metadata_yaml_files[0]
validator_opts = ValidatorOptions(docs_path=str(tmp_path), prerelease_tag=prerelease_tag) validator_opts = ValidatorOptions(docs_path=str(tmp_path), prerelease_tag=prerelease_tag)

View File

@@ -582,7 +582,7 @@ def test_upload_metadata_to_gcs_invalid_docker_images(mocker, invalid_metadata_u
def test_upload_metadata_to_gcs_with_prerelease(mocker, valid_metadata_upload_files, tmp_path): def test_upload_metadata_to_gcs_with_prerelease(mocker, valid_metadata_upload_files, tmp_path):
mocker.spy(gcs_upload, "_file_upload") mocker.spy(gcs_upload, "_file_upload")
mocker.spy(gcs_upload, "upload_file_if_changed") mocker.spy(gcs_upload, "upload_file_if_changed")
prerelease_image_tag = "1.5.6-dev.f80318f754" prerelease_image_tag = "1.5.6-preview.f80318f"
for valid_metadata_upload_file in valid_metadata_upload_files: for valid_metadata_upload_file in valid_metadata_upload_files:
tmp_metadata_file_path = tmp_path / "metadata.yaml" tmp_metadata_file_path = tmp_path / "metadata.yaml"
@@ -701,7 +701,7 @@ def test_upload_metadata_to_gcs_release_candidate(mocker, get_fixture_path, tmp_
) )
assert metadata.data.releases.rolloutConfiguration.enableProgressiveRollout assert metadata.data.releases.rolloutConfiguration.enableProgressiveRollout
prerelease_tag = "1.5.6-dev.f80318f754" if prerelease else None prerelease_tag = "1.5.6-preview.f80318f" if prerelease else None
upload_info = gcs_upload.upload_metadata_to_gcs( upload_info = gcs_upload.upload_metadata_to_gcs(
"my_bucket", "my_bucket",

View File

@@ -110,14 +110,14 @@ class PublishConnectorContext(ConnectorContext):
@property @property
def pre_release_suffix(self) -> str: def pre_release_suffix(self) -> str:
return self.git_revision[:10] return self.git_revision[:7]
@property @property
def docker_image_tag(self) -> str: def docker_image_tag(self) -> str:
# get the docker image tag from the parent class # get the docker image tag from the parent class
metadata_tag = super().docker_image_tag metadata_tag = super().docker_image_tag
if self.pre_release: if self.pre_release:
return f"{metadata_tag}-dev.{self.pre_release_suffix}" return f"{metadata_tag}-preview.{self.pre_release_suffix}"
else: else:
return metadata_tag return metadata_tag

View File

@@ -25,7 +25,7 @@ from pipelines.helpers.utils import raise_if_not_user
from pipelines.models.steps import STEP_PARAMS, Step, StepResult from pipelines.models.steps import STEP_PARAMS, Step, StepResult
# Pin the PyAirbyte version to avoid updates from breaking CI # Pin the PyAirbyte version to avoid updates from breaking CI
PYAIRBYTE_VERSION = "0.20.2" PYAIRBYTE_VERSION = "0.35.1"
class PytestStep(Step, ABC): class PytestStep(Step, ABC):

View File

@@ -156,7 +156,8 @@ class TestPyAirbyteValidationTests:
result = await PyAirbyteValidation(context_for_valid_connector)._run(mocker.MagicMock()) result = await PyAirbyteValidation(context_for_valid_connector)._run(mocker.MagicMock())
assert isinstance(result, StepResult) assert isinstance(result, StepResult)
assert result.status == StepStatus.SUCCESS assert result.status == StepStatus.SUCCESS
assert "Getting `spec` output from connector..." in result.stdout # Verify the connector name appears in output (stable across PyAirbyte versions)
assert context_for_valid_connector.connector.technical_name in (result.stdout + result.stderr)
async def test__run_validation_skip_unpublished_connector( async def test__run_validation_skip_unpublished_connector(
self, self,

View File

@@ -1,2 +1,2 @@
cdkVersion=0.1.86 cdkVersion=0.1.89
JunitMethodExecutionTimeout=10m JunitMethodExecutionTimeout=10m

View File

@@ -2,7 +2,7 @@ data:
connectorSubtype: database connectorSubtype: database
connectorType: destination connectorType: destination
definitionId: ce0d828e-1dc4-496c-b122-2da42e637e48 definitionId: ce0d828e-1dc4-496c-b122-2da42e637e48
dockerImageTag: 2.1.16-rc.2 dockerImageTag: 2.1.18
dockerRepository: airbyte/destination-clickhouse dockerRepository: airbyte/destination-clickhouse
githubIssueLabel: destination-clickhouse githubIssueLabel: destination-clickhouse
icon: clickhouse.svg icon: clickhouse.svg
@@ -27,7 +27,7 @@ data:
releaseStage: generally_available releaseStage: generally_available
releases: releases:
rolloutConfiguration: rolloutConfiguration:
enableProgressiveRollout: true enableProgressiveRollout: false
breakingChanges: breakingChanges:
2.0.0: 2.0.0:
message: "This connector has been re-written from scratch. Data will now be typed and stored in final (non-raw) tables. The connector may require changes to its configuration to function properly and downstream pipelines may be affected. Warning: SSH tunneling is in Beta." message: "This connector has been re-written from scratch. Data will now be typed and stored in final (non-raw) tables. The connector may require changes to its configuration to function properly and downstream pipelines may be affected. Warning: SSH tunneling is in Beta."

View File

@@ -54,8 +54,11 @@ class ClickhouseSqlGenerator {
// Check if cursor column type is valid for ClickHouse ReplacingMergeTree // Check if cursor column type is valid for ClickHouse ReplacingMergeTree
val cursor = tableSchema.getCursor().firstOrNull() val cursor = tableSchema.getCursor().firstOrNull()
val cursorType = cursor?.let { finalSchema[it]?.type } val cursorType = cursor?.let { finalSchema[it]?.type }
val useCursorAsVersion =
cursorType != null && isValidVersionColumn(cursor, cursorType)
val versionColumn = val versionColumn =
if (cursorType?.isValidVersionColumnType() ?: false) { if (useCursorAsVersion) {
"`$cursor`" "`$cursor`"
} else { } else {
// Fallback to _airbyte_extracted_at if no cursor is specified or cursor // Fallback to _airbyte_extracted_at if no cursor is specified or cursor

View File

@@ -4,6 +4,7 @@
package io.airbyte.integrations.destination.clickhouse.client package io.airbyte.integrations.destination.clickhouse.client
import io.airbyte.cdk.load.table.CDC_CURSOR_COLUMN
import io.airbyte.integrations.destination.clickhouse.client.ClickhouseSqlTypes.VALID_VERSION_COLUMN_TYPES import io.airbyte.integrations.destination.clickhouse.client.ClickhouseSqlTypes.VALID_VERSION_COLUMN_TYPES
object ClickhouseSqlTypes { object ClickhouseSqlTypes {
@@ -23,4 +24,9 @@ object ClickhouseSqlTypes {
) )
} }
fun String.isValidVersionColumnType() = VALID_VERSION_COLUMN_TYPES.contains(this) // Warning: if any munging changes the CDC column name, this will break.
// Currently, that is not the case.
fun isValidVersionColumn(name: String, type: String) =
// CDC cursors cannot be used as a version column since they are null
// during the initial CDC snapshot.
name != CDC_CURSOR_COLUMN && VALID_VERSION_COLUMN_TYPES.contains(type)
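For context, this predicate is what ClickhouseSqlGenerator and ClickhouseTableSchemaMapper (elsewhere in this diff) consult when choosing the ReplacingMergeTree version column. A minimal sketch of that decision follows; `chooseVersionColumn` is a name invented here for illustration and is not part of the PR.

// Illustration only: chooseVersionColumn does not exist in this PR.
fun chooseVersionColumn(cursor: String?, cursorType: String?): String =
    if (cursor != null && cursorType != null && isValidVersionColumn(cursor, cursorType)) {
        "`$cursor`" // the user-declared cursor is usable as the version column
    } else {
        // CDC cursors (null during the initial snapshot) and unsupported types
        // fall back to the extraction timestamp column.
        "`_airbyte_extracted_at`"
    }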

View File

@@ -1,62 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.clickhouse.config
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.data.Transformations.Companion.toAlphanumericAndUnderscore
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.ColumnNameGenerator
import io.airbyte.cdk.load.table.FinalTableNameGenerator
import io.airbyte.integrations.destination.clickhouse.spec.ClickhouseConfiguration
import jakarta.inject.Singleton
import java.util.Locale
import java.util.UUID
@Singleton
class ClickhouseFinalTableNameGenerator(private val config: ClickhouseConfiguration) :
FinalTableNameGenerator {
override fun getTableName(streamDescriptor: DestinationStream.Descriptor) =
TableName(
namespace =
(streamDescriptor.namespace ?: config.resolvedDatabase)
.toClickHouseCompatibleName(),
name = streamDescriptor.name.toClickHouseCompatibleName(),
)
}
@Singleton
class ClickhouseColumnNameGenerator : ColumnNameGenerator {
override fun getColumnName(column: String): ColumnNameGenerator.ColumnName {
return ColumnNameGenerator.ColumnName(
column.toClickHouseCompatibleName(),
column.lowercase(Locale.getDefault()).toClickHouseCompatibleName(),
)
}
}
/**
* Transforms a string to be compatible with ClickHouse table and column names.
*
* @return The transformed string suitable for ClickHouse identifiers.
*/
fun String.toClickHouseCompatibleName(): String {
// 1. Replace any character that is not a letter,
// a digit (0-9), or an underscore (_) with a single underscore.
var transformed = toAlphanumericAndUnderscore(this)
// 2. Ensure the identifier does not start with a digit.
// If it starts with a digit, prepend an underscore.
if (transformed.isNotEmpty() && transformed[0].isDigit()) {
transformed = "_$transformed"
}
// 3.Do not allow empty strings.
if (transformed.isEmpty()) {
return "default_name_${UUID.randomUUID()}" // A fallback name if the input results in an
// empty string
}
return transformed
}

View File

@@ -0,0 +1,33 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.clickhouse.schema
import io.airbyte.cdk.load.data.Transformations.Companion.toAlphanumericAndUnderscore
import java.util.UUID
/**
* Transforms a string to be compatible with ClickHouse table and column names.
*
* @return The transformed string suitable for ClickHouse identifiers.
*/
fun String.toClickHouseCompatibleName(): String {
// 1. Replace any character that is not a letter,
// a digit (0-9), or an underscore (_) with a single underscore.
var transformed = toAlphanumericAndUnderscore(this)
    // 2. Do not allow empty strings.
if (transformed.isEmpty()) {
return "default_name_${UUID.randomUUID()}" // A fallback name if the input results in an
// empty string
}
// 3. Ensure the identifier does not start with a digit.
// If it starts with a digit, prepend an underscore.
if (transformed[0].isDigit()) {
transformed = "_$transformed"
}
return transformed
}
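A few illustrative expectations for the function above, derived from its body and from ClickhouseNamingUtilsTest later in this diff; the wrapper function name is made up for the example.

// Illustration only; clickHouseNameExamples is not part of the PR.
fun clickHouseNameExamples() {
    check("hello world".toClickHouseCompatibleName() == "hello_world")   // non-alphanumerics become "_"
    check("1st_column".toClickHouseCompatibleName() == "_1st_column")    // leading digit gets an "_" prefix
    check("".toClickHouseCompatibleName().startsWith("default_name_"))   // empty input falls back to a random name
}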

View File

@@ -29,8 +29,7 @@ import io.airbyte.cdk.load.schema.model.StreamTableSchema
import io.airbyte.cdk.load.schema.model.TableName import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.TempTableNameGenerator import io.airbyte.cdk.load.table.TempTableNameGenerator
import io.airbyte.integrations.destination.clickhouse.client.ClickhouseSqlTypes import io.airbyte.integrations.destination.clickhouse.client.ClickhouseSqlTypes
import io.airbyte.integrations.destination.clickhouse.client.isValidVersionColumnType import io.airbyte.integrations.destination.clickhouse.client.isValidVersionColumn
import io.airbyte.integrations.destination.clickhouse.config.toClickHouseCompatibleName
import io.airbyte.integrations.destination.clickhouse.spec.ClickhouseConfiguration import io.airbyte.integrations.destination.clickhouse.spec.ClickhouseConfiguration
import jakarta.inject.Singleton import jakarta.inject.Singleton
@@ -100,7 +99,7 @@ class ClickhouseTableSchemaMapper(
if (cursor != null) { if (cursor != null) {
// Check if the cursor column type is valid for ClickHouse ReplacingMergeTree // Check if the cursor column type is valid for ClickHouse ReplacingMergeTree
val cursorColumnType = tableSchema.columnSchema.finalSchema[cursor]!!.type val cursorColumnType = tableSchema.columnSchema.finalSchema[cursor]!!.type
if (cursorColumnType.isValidVersionColumnType()) { if (isValidVersionColumn(cursor, cursorColumnType)) {
// Cursor column is valid, use it as version column // Cursor column is valid, use it as version column
add(cursor) // Make cursor column non-nullable too add(cursor) // Make cursor column non-nullable too
} }

View File

@@ -0,0 +1,77 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.clickhouse.component
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures.NEGATIVE_HIGH_PRECISION_FLOAT
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures.POSITIVE_HIGH_PRECISION_FLOAT
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures.SMALLEST_NEGATIVE_FLOAT32
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures.SMALLEST_NEGATIVE_FLOAT64
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures.SMALLEST_POSITIVE_FLOAT32
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures.SMALLEST_POSITIVE_FLOAT64
import io.airbyte.cdk.load.component.DataCoercionSuite
import io.airbyte.cdk.load.component.TableOperationsClient
import io.airbyte.cdk.load.component.TestTableOperationsClient
import io.airbyte.cdk.load.component.toArgs
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.dataflow.transform.ValueCoercer
import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Reason
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.MethodSource
@MicronautTest(environments = ["component"], resolveParameters = false)
class ClickhouseDataCoercionTest(
override val coercer: ValueCoercer,
override val opsClient: TableOperationsClient,
override val testClient: TestTableOperationsClient,
override val schemaFactory: TableSchemaFactory,
) : DataCoercionSuite {
@ParameterizedTest
// We use clickhouse's Int64 type for integers
@MethodSource("io.airbyte.cdk.load.component.DataCoercionIntegerFixtures#int64")
override fun `handle integer values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) {
super.`handle integer values`(inputValue, expectedValue, expectedChangeReason)
}
@ParameterizedTest
@MethodSource(
"io.airbyte.integrations.destination.clickhouse.component.ClickhouseDataCoercionTest#numbers"
)
override fun `handle number values`(
inputValue: AirbyteValue,
expectedValue: Any?,
expectedChangeReason: Reason?
) {
super.`handle number values`(inputValue, expectedValue, expectedChangeReason)
}
companion object {
/**
* destination-clickhouse doesn't set a change reason when truncating high-precision numbers
* (https://github.com/airbytehq/airbyte-internal-issues/issues/15401)
*/
@JvmStatic
fun numbers() =
DataCoercionNumberFixtures.numeric38_9
.map {
when (it.name) {
POSITIVE_HIGH_PRECISION_FLOAT,
NEGATIVE_HIGH_PRECISION_FLOAT,
SMALLEST_POSITIVE_FLOAT32,
SMALLEST_NEGATIVE_FLOAT32,
SMALLEST_POSITIVE_FLOAT64,
SMALLEST_NEGATIVE_FLOAT64 -> it.copy(changeReason = null)
else -> it
}
}
.toArgs()
}
}

View File

@@ -22,7 +22,7 @@ class ClickhouseTableSchemaEvolutionTest(
override val client: TableSchemaEvolutionClient, override val client: TableSchemaEvolutionClient,
override val opsClient: TableOperationsClient, override val opsClient: TableOperationsClient,
override val testClient: TestTableOperationsClient, override val testClient: TestTableOperationsClient,
override val schemaFactory: TableSchemaFactory override val schemaFactory: TableSchemaFactory,
) : TableSchemaEvolutionSuite { ) : TableSchemaEvolutionSuite {
private val allTypesTableSchema = private val allTypesTableSchema =
TableSchema( TableSchema(

View File

@@ -16,7 +16,7 @@ import io.airbyte.cdk.load.data.TimestampWithTimezoneValue
import io.airbyte.cdk.load.data.TimestampWithoutTimezoneValue import io.airbyte.cdk.load.data.TimestampWithoutTimezoneValue
import io.airbyte.cdk.load.test.util.ExpectedRecordMapper import io.airbyte.cdk.load.test.util.ExpectedRecordMapper
import io.airbyte.cdk.load.test.util.OutputRecord import io.airbyte.cdk.load.test.util.OutputRecord
import io.airbyte.integrations.destination.clickhouse.config.toClickHouseCompatibleName import io.airbyte.integrations.destination.clickhouse.schema.toClickHouseCompatibleName
import java.math.RoundingMode import java.math.RoundingMode
import java.time.LocalTime import java.time.LocalTime
import java.time.ZoneOffset import java.time.ZoneOffset

View File

@@ -30,8 +30,8 @@ import io.airbyte.cdk.load.write.UnknownTypesBehavior
import io.airbyte.integrations.destination.clickhouse.ClickhouseConfigUpdater import io.airbyte.integrations.destination.clickhouse.ClickhouseConfigUpdater
import io.airbyte.integrations.destination.clickhouse.ClickhouseContainerHelper import io.airbyte.integrations.destination.clickhouse.ClickhouseContainerHelper
import io.airbyte.integrations.destination.clickhouse.Utils import io.airbyte.integrations.destination.clickhouse.Utils
import io.airbyte.integrations.destination.clickhouse.config.toClickHouseCompatibleName
import io.airbyte.integrations.destination.clickhouse.fixtures.ClickhouseExpectedRecordMapper import io.airbyte.integrations.destination.clickhouse.fixtures.ClickhouseExpectedRecordMapper
import io.airbyte.integrations.destination.clickhouse.schema.toClickHouseCompatibleName
import io.airbyte.integrations.destination.clickhouse.spec.ClickhouseConfiguration import io.airbyte.integrations.destination.clickhouse.spec.ClickhouseConfiguration
import io.airbyte.integrations.destination.clickhouse.spec.ClickhouseConfigurationFactory import io.airbyte.integrations.destination.clickhouse.spec.ClickhouseConfigurationFactory
import io.airbyte.integrations.destination.clickhouse.spec.ClickhouseSpecificationOss import io.airbyte.integrations.destination.clickhouse.spec.ClickhouseSpecificationOss

View File

@@ -21,7 +21,6 @@ import io.airbyte.cdk.load.schema.model.StreamTableSchema
import io.airbyte.cdk.load.schema.model.TableName import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.ColumnNameMapping import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.TempTableNameGenerator import io.airbyte.cdk.load.table.TempTableNameGenerator
import io.airbyte.integrations.destination.clickhouse.config.ClickhouseFinalTableNameGenerator
import io.mockk.coEvery import io.mockk.coEvery
import io.mockk.coVerify import io.mockk.coVerify
import io.mockk.coVerifyOrder import io.mockk.coVerifyOrder
@@ -39,8 +38,6 @@ class ClickhouseAirbyteClientTest {
// Mocks // Mocks
private val client: ClickHouseClientRaw = mockk(relaxed = true) private val client: ClickHouseClientRaw = mockk(relaxed = true)
private val clickhouseSqlGenerator: ClickhouseSqlGenerator = mockk(relaxed = true) private val clickhouseSqlGenerator: ClickhouseSqlGenerator = mockk(relaxed = true)
private val clickhouseFinalTableNameGenerator: ClickhouseFinalTableNameGenerator =
mockk(relaxed = true)
private val tempTableNameGenerator: TempTableNameGenerator = mockk(relaxed = true) private val tempTableNameGenerator: TempTableNameGenerator = mockk(relaxed = true)
// Client // Client
@@ -105,7 +102,6 @@ class ClickhouseAirbyteClientTest {
alterTableStatement alterTableStatement
coEvery { clickhouseAirbyteClient.execute(alterTableStatement) } returns coEvery { clickhouseAirbyteClient.execute(alterTableStatement) } returns
mockk(relaxed = true) mockk(relaxed = true)
every { clickhouseFinalTableNameGenerator.getTableName(any()) } returns mockTableName
mockCHSchemaWithAirbyteColumns() mockCHSchemaWithAirbyteColumns()
@@ -172,7 +168,6 @@ class ClickhouseAirbyteClientTest {
coEvery { clickhouseAirbyteClient.execute(any()) } returns mockk(relaxed = true) coEvery { clickhouseAirbyteClient.execute(any()) } returns mockk(relaxed = true)
every { tempTableNameGenerator.generate(any()) } returns tempTableName every { tempTableNameGenerator.generate(any()) } returns tempTableName
every { clickhouseFinalTableNameGenerator.getTableName(any()) } returns finalTableName
mockCHSchemaWithAirbyteColumns() mockCHSchemaWithAirbyteColumns()
@@ -226,8 +221,6 @@ class ClickhouseAirbyteClientTest {
fun `test ensure schema matches fails if no airbyte columns`() = runTest { fun `test ensure schema matches fails if no airbyte columns`() = runTest {
val finalTableName = TableName("fin", "al") val finalTableName = TableName("fin", "al")
every { clickhouseFinalTableNameGenerator.getTableName(any()) } returns finalTableName
val columnMapping = ColumnNameMapping(mapOf()) val columnMapping = ColumnNameMapping(mapOf())
val stream = val stream =
mockk<DestinationStream> { mockk<DestinationStream> {

View File

@@ -2,13 +2,13 @@
* Copyright (c) 2025 Airbyte, Inc., all rights reserved. * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/ */
package io.airbyte.integrations.destination.clickhouse.config package io.airbyte.integrations.destination.clickhouse.schema
import java.util.UUID import java.util.UUID
import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
class ClickhouseNameGeneratorTest { class ClickhouseNamingUtilsTest {
@Test @Test
fun `toClickHouseCompatibleName replaces special characters with underscores`() { fun `toClickHouseCompatibleName replaces special characters with underscores`() {
Assertions.assertEquals("hello_world", "hello world".toClickHouseCompatibleName()) Assertions.assertEquals("hello_world", "hello world".toClickHouseCompatibleName())

View File

@@ -6,7 +6,7 @@ data:
connectorSubtype: database connectorSubtype: database
connectorType: destination connectorType: destination
definitionId: 25c5221d-dce2-4163-ade9-739ef790f503 definitionId: 25c5221d-dce2-4163-ade9-739ef790f503
dockerImageTag: 3.0.5-rc.1 dockerImageTag: 3.0.5
dockerRepository: airbyte/destination-postgres dockerRepository: airbyte/destination-postgres
documentationUrl: https://docs.airbyte.com/integrations/destinations/postgres documentationUrl: https://docs.airbyte.com/integrations/destinations/postgres
githubIssueLabel: destination-postgres githubIssueLabel: destination-postgres
@@ -22,7 +22,7 @@ data:
enabled: true enabled: true
releases: releases:
rolloutConfiguration: rolloutConfiguration:
enableProgressiveRollout: true enableProgressiveRollout: false
breakingChanges: breakingChanges:
3.0.0: 3.0.0:
message: > message: >

View File

@@ -4,12 +4,16 @@
package io.airbyte.integrations.destination.postgres.client package io.airbyte.integrations.destination.postgres.client
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
import io.airbyte.cdk.ConfigErrorException
import io.airbyte.cdk.load.command.DestinationStream import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.component.ColumnChangeset import io.airbyte.cdk.load.component.ColumnChangeset
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.component.TableColumns import io.airbyte.cdk.load.component.TableColumns
import io.airbyte.cdk.load.component.TableOperationsClient import io.airbyte.cdk.load.component.TableOperationsClient
import io.airbyte.cdk.load.component.TableSchema import io.airbyte.cdk.load.component.TableSchema
import io.airbyte.cdk.load.component.TableSchemaEvolutionClient import io.airbyte.cdk.load.component.TableSchemaEvolutionClient
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAMES
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.schema.model.TableName import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.ColumnNameMapping import io.airbyte.cdk.load.table.ColumnNameMapping
@@ -26,6 +30,11 @@ import javax.sql.DataSource
private val log = KotlinLogging.logger {} private val log = KotlinLogging.logger {}
@Singleton @Singleton
@SuppressFBWarnings(
value = ["SQL_NONCONSTANT_STRING_PASSED_TO_EXECUTE"],
justification =
"There is little chance of SQL injection. There is also little need for statement reuse. The basic statement is more readable than the prepared statement."
)
class PostgresAirbyteClient( class PostgresAirbyteClient(
private val dataSource: DataSource, private val dataSource: DataSource,
private val sqlGenerator: PostgresDirectLoadSqlGenerator, private val sqlGenerator: PostgresDirectLoadSqlGenerator,
@@ -53,6 +62,29 @@ class PostgresAirbyteClient(
null null
} }
override suspend fun namespaceExists(namespace: String): Boolean {
return executeQuery(
"""
SELECT EXISTS(
SELECT 1 FROM information_schema.schemata
WHERE schema_name = '$namespace'
)
"""
) { rs -> rs.next() && rs.getBoolean(1) }
}
override suspend fun tableExists(table: TableName): Boolean {
return executeQuery(
"""
SELECT EXISTS(
SELECT 1 FROM information_schema.tables
WHERE table_schema = '${table.namespace}'
AND table_name = '${table.name}'
)
"""
) { rs -> rs.next() && rs.getBoolean(1) }
}
override suspend fun createNamespace(namespace: String) { override suspend fun createNamespace(namespace: String) {
try { try {
execute(sqlGenerator.createNamespace(namespace)) execute(sqlGenerator.createNamespace(namespace))
@@ -171,14 +203,26 @@ class PostgresAirbyteClient(
} }
override suspend fun discoverSchema(tableName: TableName): TableSchema { override suspend fun discoverSchema(tableName: TableName): TableSchema {
TODO("Not yet implemented") val columnsInDb = getColumnsFromDbForDiscovery(tableName)
val hasAllAirbyteColumns = columnsInDb.keys.containsAll(COLUMN_NAMES)
if (!hasAllAirbyteColumns) {
val message =
"The target table ($tableName) already exists in the destination, but does not contain Airbyte's internal columns. Airbyte can only sync to Airbyte-controlled tables. To fix this error, you must either delete the target table or add a prefix in the connection configuration in order to sync to a separate table in the destination."
log.error { message }
throw ConfigErrorException(message)
}
// Filter out Airbyte columns
val userColumns = columnsInDb.filterKeys { it !in COLUMN_NAMES }
return TableSchema(userColumns)
} }
override fun computeSchema( override fun computeSchema(
stream: DestinationStream, stream: DestinationStream,
columnNameMapping: ColumnNameMapping columnNameMapping: ColumnNameMapping
): TableSchema { ): TableSchema {
TODO("Not yet implemented") return TableSchema(stream.tableSchema.columnSchema.finalSchema)
} }
override suspend fun applyChangeset( override suspend fun applyChangeset(
@@ -188,9 +232,73 @@ class PostgresAirbyteClient(
expectedColumns: TableColumns, expectedColumns: TableColumns,
columnChangeset: ColumnChangeset columnChangeset: ColumnChangeset
) { ) {
TODO("Not yet implemented") if (
columnChangeset.columnsToAdd.isNotEmpty() ||
columnChangeset.columnsToDrop.isNotEmpty() ||
columnChangeset.columnsToChange.isNotEmpty()
) {
log.info { "Summary of the table alterations:" }
log.info { "Added columns: ${columnChangeset.columnsToAdd}" }
log.info { "Deleted columns: ${columnChangeset.columnsToDrop}" }
log.info { "Modified columns: ${columnChangeset.columnsToChange}" }
// Convert from TableColumns format to Column format
val columnsToAdd =
columnChangeset.columnsToAdd
.map { (name, type) -> Column(name, type.type, type.nullable) }
.toSet()
val columnsToRemove =
columnChangeset.columnsToDrop
.map { (name, type) -> Column(name, type.type, type.nullable) }
.toSet()
val columnsToModify =
columnChangeset.columnsToChange
.map { (name, change) ->
Column(name, change.newType.type, change.newType.nullable)
}
.toSet()
val columnsInDb =
(columnChangeset.columnsToRetain +
columnChangeset.columnsToDrop +
columnChangeset.columnsToChange.mapValues { it.value.originalType })
.map { (name, type) -> Column(name, type.type, type.nullable) }
.toSet()
execute(
sqlGenerator.matchSchemas(
tableName = tableName,
columnsToAdd = columnsToAdd,
columnsToRemove = columnsToRemove,
columnsToModify = columnsToModify,
columnsInDb = columnsInDb,
recreatePrimaryKeyIndex = false,
primaryKeyColumnNames = emptyList(),
recreateCursorIndex = false,
cursorColumnName = null,
)
)
}
} }
/**
* Gets columns from the database including their types for schema discovery. Unlike
* [getColumnsFromDb], this returns all columns including Airbyte metadata columns.
*/
private fun getColumnsFromDbForDiscovery(tableName: TableName): Map<String, ColumnType> =
executeQuery(sqlGenerator.getTableSchema(tableName)) { rs ->
val columnsInDb: MutableMap<String, ColumnType> = mutableMapOf()
while (rs.next()) {
val columnName = rs.getString(COLUMN_NAME_COLUMN)
val dataType = rs.getString("data_type")
// PostgreSQL's information_schema always returns 'YES' or 'NO' for is_nullable
val isNullable = rs.getString("is_nullable") == "YES"
columnsInDb[columnName] = ColumnType(normalizePostgresType(dataType), isNullable)
}
columnsInDb
}
/** /**
* Checks if the primary key index matches the current stream configuration. If the primary keys * Checks if the primary key index matches the current stream configuration. If the primary keys
* have changed (detected by comparing columns in the index), then this will return true, * have changed (detected by comparing columns in the index), then this will return true,

View File

@@ -531,7 +531,7 @@ class PostgresDirectLoadSqlGenerator(
fun getTableSchema(tableName: TableName): String = fun getTableSchema(tableName: TableName): String =
""" """
SELECT column_name, data_type SELECT column_name, data_type, is_nullable
FROM information_schema.columns FROM information_schema.columns
WHERE table_schema = '${tableName.namespace}' WHERE table_schema = '${tableName.namespace}'
AND table_name = '${tableName.name}'; AND table_name = '${tableName.name}';

View File

@@ -49,6 +49,7 @@ class PostgresWriter(
override fun createStreamLoader(stream: DestinationStream): StreamLoader { override fun createStreamLoader(stream: DestinationStream): StreamLoader {
val initialStatus = initialStatuses[stream]!! val initialStatus = initialStatuses[stream]!!
val realTableName = stream.tableSchema.tableNames.finalTableName!! val realTableName = stream.tableSchema.tableNames.finalTableName!!
val tempTableName = tempTableNameGenerator.generate(realTableName) val tempTableName = tempTableNameGenerator.generate(realTableName)
val columnNameMapping = val columnNameMapping =
ColumnNameMapping(stream.tableSchema.columnSchema.inputToFinalColumnNames) ColumnNameMapping(stream.tableSchema.columnSchema.inputToFinalColumnNames)

View File

@@ -0,0 +1,45 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.postgres.component
import io.airbyte.cdk.load.util.Jsons
import io.airbyte.integrations.destination.postgres.PostgresConfigUpdater
import io.airbyte.integrations.destination.postgres.PostgresContainerHelper
import io.airbyte.integrations.destination.postgres.spec.PostgresConfiguration
import io.airbyte.integrations.destination.postgres.spec.PostgresConfigurationFactory
import io.airbyte.integrations.destination.postgres.spec.PostgresSpecificationOss
import io.micronaut.context.annotation.Factory
import io.micronaut.context.annotation.Primary
import io.micronaut.context.annotation.Requires
import jakarta.inject.Singleton
@Requires(env = ["component"])
@Factory
class PostgresComponentTestConfigFactory {
@Singleton
@Primary
fun config(): PostgresConfiguration {
// Start the postgres container
PostgresContainerHelper.start()
// Create a minimal config JSON and update it with container details
val configJson =
"""
{
"host": "replace_me_host",
"port": "replace_me_port",
"database": "replace_me_database",
"schema": "public",
"username": "replace_me_username",
"password": "replace_me_password",
"ssl": false
}
"""
val updatedConfig = PostgresConfigUpdater().update(configJson)
val spec = Jsons.readValue(updatedConfig, PostgresSpecificationOss::class.java)
return PostgresConfigurationFactory().makeWithoutExceptionHandling(spec)
}
}

View File

@@ -0,0 +1,36 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.postgres.component
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.component.TableOperationsFixtures
import io.airbyte.cdk.load.component.TableSchema
object PostgresComponentTestFixtures {
// PostgreSQL uses lowercase column names by default (no transformation needed)
val testMapping = TableOperationsFixtures.TEST_MAPPING
val idAndTestMapping = TableOperationsFixtures.ID_AND_TEST_MAPPING
val idTestWithCdcMapping = TableOperationsFixtures.ID_TEST_WITH_CDC_MAPPING
val allTypesTableSchema =
TableSchema(
mapOf(
"string" to ColumnType("varchar", true),
"boolean" to ColumnType("boolean", true),
"integer" to ColumnType("bigint", true),
"number" to ColumnType("decimal", true),
"date" to ColumnType("date", true),
"timestamp_tz" to ColumnType("timestamp with time zone", true),
"timestamp_ntz" to ColumnType("timestamp", true),
"time_tz" to ColumnType("time with time zone", true),
"time_ntz" to ColumnType("time", true),
"array" to ColumnType("jsonb", true),
"object" to ColumnType("jsonb", true),
"unknown" to ColumnType("jsonb", true),
)
)
val allTypesColumnNameMapping = TableOperationsFixtures.ALL_TYPES_MAPPING
}

View File

@@ -0,0 +1,92 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.postgres.component
import io.airbyte.cdk.load.component.TableOperationsFixtures
import io.airbyte.cdk.load.component.TableOperationsSuite
import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.integrations.destination.postgres.client.PostgresAirbyteClient
import io.airbyte.integrations.destination.postgres.component.PostgresComponentTestFixtures.idTestWithCdcMapping
import io.airbyte.integrations.destination.postgres.component.PostgresComponentTestFixtures.testMapping
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import jakarta.inject.Inject
import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Test
@MicronautTest(environments = ["component"])
class PostgresTableOperationsTest(
override val client: PostgresAirbyteClient,
override val testClient: PostgresTestTableOperationsClient,
) : TableOperationsSuite {
@Inject override lateinit var schemaFactory: TableSchemaFactory
@Test
override fun `connect to database`() {
super.`connect to database`()
}
@Test
override fun `create and drop namespaces`() {
super.`create and drop namespaces`()
}
@Test
override fun `create and drop tables`() {
super.`create and drop tables`()
}
@Test
override fun `insert records`() {
super.`insert records`(
inputRecords = TableOperationsFixtures.SINGLE_TEST_RECORD_INPUT,
expectedRecords = TableOperationsFixtures.SINGLE_TEST_RECORD_EXPECTED,
columnNameMapping = testMapping,
)
}
@Test
override fun `count table rows`() {
super.`count table rows`(columnNameMapping = testMapping)
}
@Test
override fun `overwrite tables`() {
super.`overwrite tables`(
sourceInputRecords = TableOperationsFixtures.OVERWRITE_SOURCE_RECORDS,
targetInputRecords = TableOperationsFixtures.OVERWRITE_TARGET_RECORDS,
expectedRecords = TableOperationsFixtures.OVERWRITE_EXPECTED_RECORDS,
columnNameMapping = testMapping,
)
}
@Test
override fun `copy tables`() {
super.`copy tables`(
sourceInputRecords = TableOperationsFixtures.OVERWRITE_SOURCE_RECORDS,
targetInputRecords = TableOperationsFixtures.OVERWRITE_TARGET_RECORDS,
expectedRecords = TableOperationsFixtures.COPY_EXPECTED_RECORDS,
columnNameMapping = testMapping,
)
}
@Test
override fun `get generation id`() {
super.`get generation id`(columnNameMapping = testMapping)
}
// TODO: Re-enable when CDK TableOperationsSuite is fixed to use ID_AND_TEST_SCHEMA for target
// table instead of TEST_INTEGER_SCHEMA (the Dedupe mode requires the id column as primary key)
@Disabled("CDK TableOperationsSuite bug: target table schema missing 'id' column for Dedupe")
@Test
override fun `upsert tables`() {
super.`upsert tables`(
sourceInputRecords = TableOperationsFixtures.UPSERT_SOURCE_RECORDS,
targetInputRecords = TableOperationsFixtures.UPSERT_TARGET_RECORDS,
expectedRecords = TableOperationsFixtures.UPSERT_EXPECTED_RECORDS,
columnNameMapping = idTestWithCdcMapping,
)
}
}

View File

@@ -0,0 +1,111 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.postgres.component
import io.airbyte.cdk.load.command.ImportType
import io.airbyte.cdk.load.component.TableSchemaEvolutionFixtures
import io.airbyte.cdk.load.component.TableSchemaEvolutionSuite
import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.integrations.destination.postgres.client.PostgresAirbyteClient
import io.airbyte.integrations.destination.postgres.component.PostgresComponentTestFixtures.allTypesColumnNameMapping
import io.airbyte.integrations.destination.postgres.component.PostgresComponentTestFixtures.allTypesTableSchema
import io.airbyte.integrations.destination.postgres.component.PostgresComponentTestFixtures.idAndTestMapping
import io.airbyte.integrations.destination.postgres.component.PostgresComponentTestFixtures.testMapping
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import org.junit.jupiter.api.Test
@MicronautTest(environments = ["component"], resolveParameters = false)
class PostgresTableSchemaEvolutionTest(
override val client: PostgresAirbyteClient,
override val opsClient: PostgresAirbyteClient,
override val testClient: PostgresTestTableOperationsClient,
override val schemaFactory: TableSchemaFactory,
) : TableSchemaEvolutionSuite {
@Test
fun `discover recognizes all data types`() {
super.`discover recognizes all data types`(allTypesTableSchema, allTypesColumnNameMapping)
}
@Test
fun `computeSchema handles all data types`() {
super.`computeSchema handles all data types`(allTypesTableSchema, allTypesColumnNameMapping)
}
@Test
override fun `noop diff`() {
super.`noop diff`(testMapping)
}
@Test
override fun `changeset is correct when adding a column`() {
super.`changeset is correct when adding a column`(testMapping, idAndTestMapping)
}
@Test
override fun `changeset is correct when dropping a column`() {
super.`changeset is correct when dropping a column`(idAndTestMapping, testMapping)
}
@Test
override fun `changeset is correct when changing a column's type`() {
super.`changeset is correct when changing a column's type`(testMapping)
}
@Test
override fun `apply changeset - handle sync mode append`() {
super.`apply changeset - handle sync mode append`()
}
@Test
override fun `apply changeset - handle changing sync mode from append to dedup`() {
super.`apply changeset - handle changing sync mode from append to dedup`()
}
@Test
override fun `apply changeset - handle changing sync mode from dedup to append`() {
super.`apply changeset - handle changing sync mode from dedup to append`()
}
@Test
override fun `apply changeset - handle sync mode dedup`() {
super.`apply changeset - handle sync mode dedup`()
}
override fun `apply changeset`(
initialStreamImportType: ImportType,
modifiedStreamImportType: ImportType,
) {
super.`apply changeset`(
initialColumnNameMapping =
TableSchemaEvolutionFixtures.APPLY_CHANGESET_INITIAL_COLUMN_MAPPING,
modifiedColumnNameMapping =
TableSchemaEvolutionFixtures.APPLY_CHANGESET_MODIFIED_COLUMN_MAPPING,
TableSchemaEvolutionFixtures.APPLY_CHANGESET_EXPECTED_EXTRACTED_AT,
initialStreamImportType,
modifiedStreamImportType,
)
}
@Test
override fun `change from string type to unknown type`() {
super.`change from string type to unknown type`(
idAndTestMapping,
idAndTestMapping,
TableSchemaEvolutionFixtures.STRING_TO_UNKNOWN_TYPE_INPUT_RECORDS,
TableSchemaEvolutionFixtures.STRING_TO_UNKNOWN_TYPE_EXPECTED_RECORDS,
)
}
@Test
override fun `change from unknown type to string type`() {
super.`change from unknown type to string type`(
idAndTestMapping,
idAndTestMapping,
TableSchemaEvolutionFixtures.UNKNOWN_TO_STRING_TYPE_INPUT_RECORDS,
TableSchemaEvolutionFixtures.UNKNOWN_TO_STRING_TYPE_EXPECTED_RECORDS,
)
}
}

View File

@@ -0,0 +1,257 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.postgres.component
import io.airbyte.cdk.load.component.TestTableOperationsClient
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.util.Jsons
import io.airbyte.integrations.destination.postgres.client.PostgresAirbyteClient
import io.micronaut.context.annotation.Requires
import jakarta.inject.Singleton
import java.time.OffsetDateTime
import java.time.ZoneOffset
import java.time.format.DateTimeFormatter
import javax.sql.DataSource
@Requires(env = ["component"])
@Singleton
class PostgresTestTableOperationsClient(
private val dataSource: DataSource,
private val client: PostgresAirbyteClient,
) : TestTableOperationsClient {
override suspend fun ping() {
dataSource.connection.use { connection ->
connection.createStatement().use { statement -> statement.executeQuery("SELECT 1") }
}
}
override suspend fun dropNamespace(namespace: String) {
dataSource.connection.use { connection ->
connection.createStatement().use { statement ->
statement.execute("DROP SCHEMA IF EXISTS \"$namespace\" CASCADE")
}
}
}
override suspend fun insertRecords(table: TableName, records: List<Map<String, AirbyteValue>>) {
if (records.isEmpty()) return
// Get column types from database to handle jsonb columns properly
val columnTypes = getColumnTypes(table)
// Get all unique columns from ALL records to handle sparse data (e.g., CDC deletion column)
val columns = records.flatMap { it.keys }.distinct().toList()
val columnNames = columns.joinToString(", ") { "\"$it\"" }
val placeholders = columns.indices.joinToString(", ") { "?" }
val sql =
"""
INSERT INTO "${table.namespace}"."${table.name}" ($columnNames)
VALUES ($placeholders)
"""
dataSource.connection.use { connection ->
connection.prepareStatement(sql).use { statement ->
for (record in records) {
columns.forEachIndexed { index, column ->
val value = record[column]
val columnType = columnTypes[column]
setParameterValue(statement, index + 1, value, columnType)
}
statement.addBatch()
}
statement.executeBatch()
}
}
}
private fun getColumnTypes(table: TableName): Map<String, String> {
val columnTypes = mutableMapOf<String, String>()
dataSource.connection.use { connection ->
connection.createStatement().use { statement ->
statement
.executeQuery(
"""
SELECT column_name, data_type
FROM information_schema.columns
WHERE table_schema = '${table.namespace}'
AND table_name = '${table.name}'
"""
)
.use { resultSet ->
while (resultSet.next()) {
columnTypes[resultSet.getString("column_name")] =
resultSet.getString("data_type")
}
}
}
}
return columnTypes
}
private fun setParameterValue(
statement: java.sql.PreparedStatement,
index: Int,
value: AirbyteValue?,
columnType: String?
) {
// If column is jsonb, serialize any value as JSON
if (columnType == "jsonb") {
if (value == null || value is io.airbyte.cdk.load.data.NullValue) {
statement.setNull(index, java.sql.Types.OTHER)
} else {
val pgObject = org.postgresql.util.PGobject()
pgObject.type = "jsonb"
pgObject.value = serializeToJson(value)
statement.setObject(index, pgObject)
}
return
}
when (value) {
null,
is io.airbyte.cdk.load.data.NullValue -> statement.setNull(index, java.sql.Types.NULL)
is io.airbyte.cdk.load.data.StringValue -> statement.setString(index, value.value)
is io.airbyte.cdk.load.data.IntegerValue ->
statement.setLong(index, value.value.toLong())
is io.airbyte.cdk.load.data.NumberValue -> statement.setBigDecimal(index, value.value)
is io.airbyte.cdk.load.data.BooleanValue -> statement.setBoolean(index, value.value)
is io.airbyte.cdk.load.data.TimestampWithTimezoneValue -> {
val offsetDateTime = OffsetDateTime.parse(value.value.toString())
statement.setObject(index, offsetDateTime)
}
is io.airbyte.cdk.load.data.TimestampWithoutTimezoneValue -> {
val localDateTime = java.time.LocalDateTime.parse(value.value.toString())
statement.setObject(index, localDateTime)
}
is io.airbyte.cdk.load.data.DateValue -> {
val localDate = java.time.LocalDate.parse(value.value.toString())
statement.setObject(index, localDate)
}
is io.airbyte.cdk.load.data.TimeWithTimezoneValue -> {
statement.setString(index, value.value.toString())
}
is io.airbyte.cdk.load.data.TimeWithoutTimezoneValue -> {
val localTime = java.time.LocalTime.parse(value.value.toString())
statement.setObject(index, localTime)
}
is io.airbyte.cdk.load.data.ObjectValue -> {
val pgObject = org.postgresql.util.PGobject()
pgObject.type = "jsonb"
pgObject.value = Jsons.writeValueAsString(value.values)
statement.setObject(index, pgObject)
}
is io.airbyte.cdk.load.data.ArrayValue -> {
val pgObject = org.postgresql.util.PGobject()
pgObject.type = "jsonb"
pgObject.value = Jsons.writeValueAsString(value.values)
statement.setObject(index, pgObject)
}
else -> {
// For unknown types, try to serialize as string
statement.setString(index, value.toString())
}
}
}
private fun serializeToJson(value: AirbyteValue): String {
return when (value) {
is io.airbyte.cdk.load.data.StringValue -> Jsons.writeValueAsString(value.value)
is io.airbyte.cdk.load.data.IntegerValue -> value.value.toString()
is io.airbyte.cdk.load.data.NumberValue -> value.value.toString()
is io.airbyte.cdk.load.data.BooleanValue -> value.value.toString()
is io.airbyte.cdk.load.data.ObjectValue -> Jsons.writeValueAsString(value.values)
is io.airbyte.cdk.load.data.ArrayValue -> Jsons.writeValueAsString(value.values)
is io.airbyte.cdk.load.data.NullValue -> "null"
else -> Jsons.writeValueAsString(value.toString())
}
}
override suspend fun readTable(table: TableName): List<Map<String, Any>> {
dataSource.connection.use { connection ->
connection.createStatement().use { statement ->
statement
.executeQuery("""SELECT * FROM "${table.namespace}"."${table.name}"""")
.use { resultSet ->
val metaData = resultSet.metaData
val columnCount = metaData.columnCount
val result = mutableListOf<Map<String, Any>>()
while (resultSet.next()) {
val row = mutableMapOf<String, Any>()
for (i in 1..columnCount) {
val columnName = metaData.getColumnName(i)
val columnType = metaData.getColumnTypeName(i)
when (columnType.lowercase()) {
"timestamptz" -> {
val value =
resultSet.getObject(i, OffsetDateTime::class.java)
if (value != null) {
val formattedTimestamp =
DateTimeFormatter.ISO_OFFSET_DATE_TIME.format(
value.withOffsetSameInstant(ZoneOffset.UTC)
)
row[columnName] = formattedTimestamp
}
}
"timestamp" -> {
val value = resultSet.getTimestamp(i)
if (value != null) {
val localDateTime = value.toLocalDateTime()
row[columnName] =
DateTimeFormatter.ISO_LOCAL_DATE_TIME.format(
localDateTime
)
}
}
"jsonb",
"json" -> {
val stringValue: String? = resultSet.getString(i)
if (stringValue != null) {
val parsedValue =
Jsons.readValue(stringValue, Any::class.java)
val actualValue =
when (parsedValue) {
is Int -> parsedValue.toLong()
else -> parsedValue
}
row[columnName] = actualValue
}
}
else -> {
val value = resultSet.getObject(i)
if (value != null) {
// For varchar columns that may contain JSON (from
// schema evolution),
// normalize the JSON to compact format for comparison
if (
value is String &&
(value.startsWith("{") || value.startsWith("["))
) {
try {
val parsed =
Jsons.readValue(value, Any::class.java)
row[columnName] =
Jsons.writeValueAsString(parsed)
} catch (_: Exception) {
row[columnName] = value
}
} else {
row[columnName] = value
}
}
}
}
}
result.add(row)
}
return result
}
}
}
}
}
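// --- Illustrative sketch (not part of this change): a minimal usage of the client above.
// Assumes a component-test context where a hypothetical "users" table already exists with an
// integer "id" column and a jsonb "payload" column; the schema, table, and column names are made
// up. Uses AirbyteValue and TableName as imported by this file, plus
// io.airbyte.cdk.load.data.ObjectValue.
suspend fun exampleRoundTrip(ops: PostgresTestTableOperationsClient) {
    val table = TableName(namespace = "component_test_schema", name = "users")
    ops.insertRecords(
        table,
        listOf(
            mapOf(
                "id" to AirbyteValue.from(1),
                // ObjectValue is bound as a jsonb PGobject by setParameterValue above.
                "payload" to ObjectValue(linkedMapOf("name" to AirbyteValue.from("ada"))),
            ),
        ),
    )
    // readTable parses json/jsonb columns and normalizes timestamptz values to UTC ISO-8601.
    val rows = ops.readTable(table)
    check(rows.size == 1)
}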

View File

@@ -267,7 +267,7 @@ class PostgresRawDataDumper(
.lowercase() .lowercase()
.toPostgresCompatibleName() .toPostgresCompatibleName()
val fullyQualifiedTableName = "$rawNamespace.$rawName" val fullyQualifiedTableName = "\"$rawNamespace\".\"$rawName\""
// Check if table exists first // Check if table exists first
val tableExistsQuery = val tableExistsQuery =
@@ -302,6 +302,26 @@ class PostgresRawDataDumper(
false false
} }
// Build the column name mapping from original names to transformed names
// We use the stream schema to get the original field names, then transform them
// using the postgres name transformation logic
val finalToInputColumnNames = mutableMapOf<String, String>()
if (stream.schema is ObjectType) {
val objectSchema = stream.schema as ObjectType
for (fieldName in objectSchema.properties.keys) {
val transformedName = fieldName.toPostgresCompatibleName()
// Map transformed name back to original name
finalToInputColumnNames[transformedName] = fieldName
}
}
// Also check if inputToFinalColumnNames mapping is available
val inputToFinalColumnNames =
stream.tableSchema.columnSchema.inputToFinalColumnNames
// Add entries from the existing mapping (in case it was populated)
for ((input, final) in inputToFinalColumnNames) {
finalToInputColumnNames[final] = input
}
while (resultSet.next()) { while (resultSet.next()) {
val rawData = val rawData =
if (hasDataColumn) { if (hasDataColumn) {
@@ -313,8 +333,22 @@ class PostgresRawDataDumper(
else -> dataObject?.toString() ?: "{}" else -> dataObject?.toString() ?: "{}"
} }
// Parse JSON to AirbyteValue, then coerce it to match the schema // Parse JSON to AirbyteValue, then map column names back to originals
dataJson?.deserializeToNode()?.toAirbyteValue() ?: NullValue val parsedValue =
dataJson?.deserializeToNode()?.toAirbyteValue() ?: NullValue
// If the parsed value is an ObjectValue, map the column names back
if (parsedValue is ObjectValue) {
val mappedProperties = linkedMapOf<String, AirbyteValue>()
for ((key, value) in parsedValue.values) {
// Map final column name back to input column name if mapping
// exists
val originalKey = finalToInputColumnNames[key] ?: key
mappedProperties[originalKey] = value
}
ObjectValue(mappedProperties)
} else {
parsedValue
}
} else { } else {
// Typed table mode: read from individual columns and reconstruct the // Typed table mode: read from individual columns and reconstruct the
// object // object
@@ -333,10 +367,19 @@ class PostgresRawDataDumper(
for ((fieldName, fieldType) in objectSchema.properties) { for ((fieldName, fieldType) in objectSchema.properties) {
try { try {
// Map input field name to the transformed final column name.
// First check the inputToFinalColumnNames mapping, then fall
// back to applying the postgres transformation directly.
val transformedColumnName =
inputToFinalColumnNames[fieldName]
?: fieldName.toPostgresCompatibleName()
// Try to find the actual column name (case-insensitive // Try to find the actual column name (case-insensitive
// lookup) // lookup)
val actualColumnName = val actualColumnName =
columnMap[fieldName.lowercase()] ?: fieldName columnMap[transformedColumnName.lowercase()]
?: transformedColumnName
val columnValue = resultSet.getObject(actualColumnName) val columnValue = resultSet.getObject(actualColumnName)
properties[fieldName] = properties[fieldName] =
when (columnValue) { when (columnValue) {
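// --- Illustrative sketch (not part of this change): the column-name round trip described above,
// with made-up field names. Assumes the CDK value types already used by this dumper
// (AirbyteValue, ObjectValue).
fun remapExample(): ObjectValue {
    // Suppose the source field "User Id" was written to Postgres as "user_id".
    val finalToInput = mapOf("user_id" to "User Id")
    val fromDb = linkedMapOf<String, AirbyteValue>("user_id" to AirbyteValue.from(42))
    val remapped = linkedMapOf<String, AirbyteValue>()
    for ((key, value) in fromDb) {
        remapped[finalToInput[key] ?: key] = value
    }
    return ObjectValue(remapped) // keyed by "User Id" again, matching the input schema
}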

View File

@@ -1,3 +1,3 @@
testExecutionConcurrency=-1 testExecutionConcurrency=-1
cdkVersion=0.1.82 cdkVersion=0.1.91
JunitMethodExecutionTimeout=10m JunitMethodExecutionTimeout=10m

View File

@@ -6,7 +6,7 @@ data:
connectorSubtype: database connectorSubtype: database
connectorType: destination connectorType: destination
definitionId: 424892c4-daac-4491-b35d-c6688ba547ba definitionId: 424892c4-daac-4491-b35d-c6688ba547ba
dockerImageTag: 4.0.31 dockerImageTag: 4.0.32-rc.1
dockerRepository: airbyte/destination-snowflake dockerRepository: airbyte/destination-snowflake
documentationUrl: https://docs.airbyte.com/integrations/destinations/snowflake documentationUrl: https://docs.airbyte.com/integrations/destinations/snowflake
githubIssueLabel: destination-snowflake githubIssueLabel: destination-snowflake
@@ -31,6 +31,8 @@ data:
enabled: true enabled: true
releaseStage: generally_available releaseStage: generally_available
releases: releases:
rolloutConfiguration:
enableProgressiveRollout: true
breakingChanges: breakingChanges:
2.0.0: 2.0.0:
message: Remove GCS/S3 loading method support. message: Remove GCS/S3 loading method support.

View File

@@ -11,15 +11,18 @@ import io.airbyte.cdk.load.check.CheckOperationV2
import io.airbyte.cdk.load.check.DestinationCheckerV2 import io.airbyte.cdk.load.check.DestinationCheckerV2
import io.airbyte.cdk.load.config.DataChannelMedium import io.airbyte.cdk.load.config.DataChannelMedium
import io.airbyte.cdk.load.dataflow.config.AggregatePublishingConfig import io.airbyte.cdk.load.dataflow.config.AggregatePublishingConfig
import io.airbyte.cdk.load.orchestration.db.DefaultTempTableNameGenerator import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator
import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator import io.airbyte.cdk.load.table.TempTableNameGenerator
import io.airbyte.cdk.output.OutputConsumer import io.airbyte.cdk.output.OutputConsumer
import io.airbyte.integrations.destination.snowflake.cdk.SnowflakeMigratingConfigurationSpecificationSupplier import io.airbyte.integrations.destination.snowflake.cdk.SnowflakeMigratingConfigurationSpecificationSupplier
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.KeyPairAuthConfiguration import io.airbyte.integrations.destination.snowflake.spec.KeyPairAuthConfiguration
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfigurationFactory import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfigurationFactory
import io.airbyte.integrations.destination.snowflake.spec.UsernamePasswordAuthConfiguration import io.airbyte.integrations.destination.snowflake.spec.UsernamePasswordAuthConfiguration
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRawRecordFormatter
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter
import io.micronaut.context.annotation.Factory import io.micronaut.context.annotation.Factory
import io.micronaut.context.annotation.Primary import io.micronaut.context.annotation.Primary
import io.micronaut.context.annotation.Requires import io.micronaut.context.annotation.Requires
@@ -204,6 +207,17 @@ class SnowflakeBeanFactory {
outputConsumer: OutputConsumer, outputConsumer: OutputConsumer,
) = CheckOperationV2(destinationChecker, outputConsumer) ) = CheckOperationV2(destinationChecker, outputConsumer)
@Singleton
fun snowflakeRecordFormatter(
snowflakeConfiguration: SnowflakeConfiguration
): SnowflakeRecordFormatter {
return if (snowflakeConfiguration.legacyRawTablesOnly) {
SnowflakeRawRecordFormatter()
} else {
SnowflakeSchemaRecordFormatter()
}
}
@Singleton @Singleton
fun aggregatePublishingConfig(dataChannelMedium: DataChannelMedium): AggregatePublishingConfig { fun aggregatePublishingConfig(dataChannelMedium: DataChannelMedium): AggregatePublishingConfig {
// NOT speed mode // NOT speed mode

View File

@@ -13,13 +13,17 @@ import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.ObjectType import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.data.StringType import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.message.Meta import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.cdk.load.schema.model.StreamTableSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.schema.model.TableNames
import io.airbyte.cdk.load.table.ColumnNameMapping import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.TableName
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter
import jakarta.inject.Singleton import jakarta.inject.Singleton
import java.time.OffsetDateTime import java.time.OffsetDateTime
import java.util.UUID import java.util.UUID
@@ -31,7 +35,7 @@ internal const val CHECK_COLUMN_NAME = "test_key"
class SnowflakeChecker( class SnowflakeChecker(
private val snowflakeAirbyteClient: SnowflakeAirbyteClient, private val snowflakeAirbyteClient: SnowflakeAirbyteClient,
private val snowflakeConfiguration: SnowflakeConfiguration, private val snowflakeConfiguration: SnowflakeConfiguration,
private val snowflakeColumnUtils: SnowflakeColumnUtils, private val columnManager: SnowflakeColumnManager,
) : DestinationCheckerV2 { ) : DestinationCheckerV2 {
override fun check() { override fun check() {
@@ -46,11 +50,40 @@ class SnowflakeChecker(
Meta.AirbyteMetaFields.GENERATION_ID.fieldName to AirbyteValue.from(0), Meta.AirbyteMetaFields.GENERATION_ID.fieldName to AirbyteValue.from(0),
CHECK_COLUMN_NAME.toSnowflakeCompatibleName() to AirbyteValue.from("test-value") CHECK_COLUMN_NAME.toSnowflakeCompatibleName() to AirbyteValue.from("test-value")
) )
val outputSchema = snowflakeConfiguration.schema.toSnowflakeCompatibleName() val outputSchema =
if (snowflakeConfiguration.legacyRawTablesOnly) {
snowflakeConfiguration.schema
} else {
snowflakeConfiguration.schema.toSnowflakeCompatibleName()
}
val tableName = val tableName =
"_airbyte_connection_test_${ "_airbyte_connection_test_${
UUID.randomUUID().toString().replace("-".toRegex(), "")}".toSnowflakeCompatibleName() UUID.randomUUID().toString().replace("-".toRegex(), "")}".toSnowflakeCompatibleName()
val qualifiedTableName = TableName(namespace = outputSchema, name = tableName) val qualifiedTableName = TableName(namespace = outputSchema, name = tableName)
val tableSchema =
StreamTableSchema(
tableNames =
TableNames(
finalTableName = qualifiedTableName,
tempTableName = qualifiedTableName
),
columnSchema =
ColumnSchema(
inputToFinalColumnNames =
mapOf(
CHECK_COLUMN_NAME to CHECK_COLUMN_NAME.toSnowflakeCompatibleName()
),
finalSchema =
mapOf(
CHECK_COLUMN_NAME.toSnowflakeCompatibleName() to
io.airbyte.cdk.load.component.ColumnType("VARCHAR", false)
),
inputSchema =
mapOf(CHECK_COLUMN_NAME to FieldType(StringType, nullable = false))
),
importType = Append
)
val destinationStream = val destinationStream =
DestinationStream( DestinationStream(
unmappedNamespace = outputSchema, unmappedNamespace = outputSchema,
@@ -63,7 +96,8 @@ class SnowflakeChecker(
generationId = 0L, generationId = 0L,
minimumGenerationId = 0L, minimumGenerationId = 0L,
syncId = 0L, syncId = 0L,
namespaceMapper = NamespaceMapper() namespaceMapper = NamespaceMapper(),
tableSchema = tableSchema
) )
runBlocking { runBlocking {
try { try {
@@ -75,14 +109,14 @@ class SnowflakeChecker(
replace = true, replace = true,
) )
val columns = snowflakeAirbyteClient.describeTable(qualifiedTableName)
val snowflakeInsertBuffer = val snowflakeInsertBuffer =
SnowflakeInsertBuffer( SnowflakeInsertBuffer(
tableName = qualifiedTableName, tableName = qualifiedTableName,
columns = columns,
snowflakeClient = snowflakeAirbyteClient, snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration, snowflakeConfiguration = snowflakeConfiguration,
snowflakeColumnUtils = snowflakeColumnUtils, columnSchema = tableSchema.columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = SnowflakeSchemaRecordFormatter(),
) )
snowflakeInsertBuffer.accumulate(data) snowflakeInsertBuffer.accumulate(data)

View File

@@ -13,18 +13,16 @@ import io.airbyte.cdk.load.component.TableColumns
import io.airbyte.cdk.load.component.TableOperationsClient import io.airbyte.cdk.load.component.TableOperationsClient
import io.airbyte.cdk.load.component.TableSchema import io.airbyte.cdk.load.component.TableSchema
import io.airbyte.cdk.load.component.TableSchemaEvolutionClient import io.airbyte.cdk.load.component.TableSchemaEvolutionClient
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.ColumnNameMapping import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.util.deserializeToNode import io.airbyte.cdk.load.util.deserializeToNode
import io.airbyte.integrations.destination.snowflake.db.ColumnDefinition import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.COUNT_TOTAL_ALIAS import io.airbyte.integrations.destination.snowflake.sql.COUNT_TOTAL_ALIAS
import io.airbyte.integrations.destination.snowflake.sql.NOT_NULL
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
import io.airbyte.integrations.destination.snowflake.sql.andLog import io.airbyte.integrations.destination.snowflake.sql.andLog
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
import io.github.oshai.kotlinlogging.KotlinLogging import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton import jakarta.inject.Singleton
import java.sql.ResultSet import java.sql.ResultSet
@@ -41,13 +39,10 @@ private val log = KotlinLogging.logger {}
class SnowflakeAirbyteClient( class SnowflakeAirbyteClient(
private val dataSource: DataSource, private val dataSource: DataSource,
private val sqlGenerator: SnowflakeDirectLoadSqlGenerator, private val sqlGenerator: SnowflakeDirectLoadSqlGenerator,
private val snowflakeColumnUtils: SnowflakeColumnUtils,
private val snowflakeConfiguration: SnowflakeConfiguration, private val snowflakeConfiguration: SnowflakeConfiguration,
private val columnManager: SnowflakeColumnManager,
) : TableOperationsClient, TableSchemaEvolutionClient { ) : TableOperationsClient, TableSchemaEvolutionClient {
private val airbyteColumnNames =
snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet()
override suspend fun countTable(tableName: TableName): Long? = override suspend fun countTable(tableName: TableName): Long? =
try { try {
dataSource.connection.use { connection -> dataSource.connection.use { connection ->
@@ -126,7 +121,7 @@ class SnowflakeAirbyteClient(
columnNameMapping: ColumnNameMapping, columnNameMapping: ColumnNameMapping,
replace: Boolean replace: Boolean
) { ) {
execute(sqlGenerator.createTable(stream, tableName, columnNameMapping, replace)) execute(sqlGenerator.createTable(tableName, stream.tableSchema, replace))
execute(sqlGenerator.createSnowflakeStage(tableName)) execute(sqlGenerator.createSnowflakeStage(tableName))
} }
@@ -163,7 +158,15 @@ class SnowflakeAirbyteClient(
sourceTableName: TableName, sourceTableName: TableName,
targetTableName: TableName targetTableName: TableName
) { ) {
execute(sqlGenerator.copyTable(columnNameMapping, sourceTableName, targetTableName)) // Get all column names from the mapping (both meta columns and user columns)
val columnNames = buildSet {
// Add Airbyte meta columns (using uppercase constants)
addAll(columnManager.getMetaColumnNames())
// Add user columns from mapping
addAll(columnNameMapping.values)
}
execute(sqlGenerator.copyTable(columnNames, sourceTableName, targetTableName))
} }
override suspend fun upsertTable( override suspend fun upsertTable(
@@ -172,9 +175,7 @@ class SnowflakeAirbyteClient(
sourceTableName: TableName, sourceTableName: TableName,
targetTableName: TableName targetTableName: TableName
) { ) {
execute( execute(sqlGenerator.upsertTable(stream.tableSchema, sourceTableName, targetTableName))
sqlGenerator.upsertTable(stream, columnNameMapping, sourceTableName, targetTableName)
)
} }
override suspend fun dropTable(tableName: TableName) { override suspend fun dropTable(tableName: TableName) {
@@ -206,7 +207,7 @@ class SnowflakeAirbyteClient(
stream: DestinationStream, stream: DestinationStream,
columnNameMapping: ColumnNameMapping columnNameMapping: ColumnNameMapping
): TableSchema { ): TableSchema {
return TableSchema(getColumnsFromStream(stream, columnNameMapping)) return TableSchema(stream.tableSchema.columnSchema.finalSchema)
} }
override suspend fun applyChangeset( override suspend fun applyChangeset(
@@ -253,7 +254,7 @@ class SnowflakeAirbyteClient(
val columnName = escapeJsonIdentifier(rs.getString("name")) val columnName = escapeJsonIdentifier(rs.getString("name"))
// Filter out airbyte columns // Filter out airbyte columns
if (airbyteColumnNames.contains(columnName)) { if (columnManager.getMetaColumnNames().contains(columnName)) {
continue continue
} }
val dataType = rs.getString("type").takeWhile { char -> char != '(' } val dataType = rs.getString("type").takeWhile { char -> char != '(' }
@@ -271,49 +272,6 @@ class SnowflakeAirbyteClient(
} }
} }
internal fun getColumnsFromStream(
stream: DestinationStream,
columnNameMapping: ColumnNameMapping
): Map<String, ColumnType> =
snowflakeColumnUtils
.columnsAndTypes(stream.schema.asColumns(), columnNameMapping)
.filter { column -> column.columnName !in airbyteColumnNames }
.associate { column ->
// columnsAndTypes returns types as either `FOO` or `FOO NOT NULL`.
// so check for that suffix.
val nullable = !column.columnType.endsWith(NOT_NULL)
val type =
column.columnType
.takeWhile { char ->
// This is to remove any precision parts of the dialect type
char != '('
}
.removeSuffix(NOT_NULL)
.trim()
column.columnName to ColumnType(type, nullable)
}
internal fun generateSchemaChanges(
columnsInDb: Set<ColumnDefinition>,
columnsInStream: Set<ColumnDefinition>
): Triple<Set<ColumnDefinition>, Set<ColumnDefinition>, Set<ColumnDefinition>> {
val addedColumns =
columnsInStream.filter { it.name !in columnsInDb.map { col -> col.name } }.toSet()
val deletedColumns =
columnsInDb.filter { it.name !in columnsInStream.map { col -> col.name } }.toSet()
val commonColumns =
columnsInStream.filter { it.name in columnsInDb.map { col -> col.name } }.toSet()
val modifiedColumns =
commonColumns
.filter {
val dbType = columnsInDb.find { column -> it.name == column.name }?.type
it.type != dbType
}
.toSet()
return Triple(addedColumns, deletedColumns, modifiedColumns)
}
override suspend fun getGenerationId(tableName: TableName): Long = override suspend fun getGenerationId(tableName: TableName): Long =
try { try {
dataSource.connection.use { connection -> dataSource.connection.use { connection ->
@@ -326,7 +284,7 @@ class SnowflakeAirbyteClient(
* format. In order to make sure these strings will match any column names * format. In order to make sure these strings will match any column names
* that we have formatted in-memory, re-apply the escaping. * that we have formatted in-memory, re-apply the escaping.
*/ */
resultSet.getLong(snowflakeColumnUtils.getGenerationIdColumnName()) resultSet.getLong(columnManager.getGenerationIdColumnName())
} else { } else {
log.warn { log.warn {
"No generation ID found for table ${tableName.toPrettyString()}, returning 0" "No generation ID found for table ${tableName.toPrettyString()}, returning 0"
@@ -351,8 +309,8 @@ class SnowflakeAirbyteClient(
execute(sqlGenerator.putInStage(tableName, tempFilePath)) execute(sqlGenerator.putInStage(tableName, tempFilePath))
} }
fun copyFromStage(tableName: TableName, filename: String) { fun copyFromStage(tableName: TableName, filename: String, columnNames: List<String>) {
execute(sqlGenerator.copyFromStage(tableName, filename)) execute(sqlGenerator.copyFromStage(tableName, filename, columnNames))
} }
fun describeTable(tableName: TableName): LinkedHashMap<String, String> = fun describeTable(tableName: TableName): LinkedHashMap<String, String> =

View File

@@ -4,47 +4,41 @@
package io.airbyte.integrations.destination.snowflake.dataflow package io.airbyte.integrations.destination.snowflake.dataflow
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.dataflow.aggregate.Aggregate import io.airbyte.cdk.load.dataflow.aggregate.Aggregate
import io.airbyte.cdk.load.dataflow.aggregate.AggregateFactory import io.airbyte.cdk.load.dataflow.aggregate.AggregateFactory
import io.airbyte.cdk.load.dataflow.aggregate.StoreKey import io.airbyte.cdk.load.dataflow.aggregate.StoreKey
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.write.StreamStateStore import io.airbyte.cdk.load.write.StreamStateStore
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
import io.micronaut.cache.annotation.CacheConfig import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
import io.micronaut.cache.annotation.Cacheable
import jakarta.inject.Singleton import jakarta.inject.Singleton
@Singleton @Singleton
@CacheConfig("table-columns") class SnowflakeAggregateFactory(
// class has to be open to make the cache stuff work
open class SnowflakeAggregateFactory(
private val snowflakeClient: SnowflakeAirbyteClient, private val snowflakeClient: SnowflakeAirbyteClient,
private val streamStateStore: StreamStateStore<DirectLoadTableExecutionConfig>, private val streamStateStore: StreamStateStore<DirectLoadTableExecutionConfig>,
private val snowflakeConfiguration: SnowflakeConfiguration, private val snowflakeConfiguration: SnowflakeConfiguration,
private val snowflakeColumnUtils: SnowflakeColumnUtils, private val catalog: DestinationCatalog,
private val columnManager: SnowflakeColumnManager,
private val snowflakeRecordFormatter: SnowflakeRecordFormatter,
) : AggregateFactory { ) : AggregateFactory {
override fun create(key: StoreKey): Aggregate { override fun create(key: StoreKey): Aggregate {
val stream = catalog.getStream(key)
val tableName = streamStateStore.get(key)!!.tableName val tableName = streamStateStore.get(key)!!.tableName
val buffer = val buffer =
SnowflakeInsertBuffer( SnowflakeInsertBuffer(
tableName = tableName, tableName = tableName,
columns = getTableColumns(tableName),
snowflakeClient = snowflakeClient, snowflakeClient = snowflakeClient,
snowflakeConfiguration = snowflakeConfiguration, snowflakeConfiguration = snowflakeConfiguration,
snowflakeColumnUtils = snowflakeColumnUtils, columnSchema = stream.tableSchema.columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
) )
return SnowflakeAggregate(buffer = buffer) return SnowflakeAggregate(buffer = buffer)
} }
// We assume that a table isn't getting altered _during_ a sync.
// This allows us to only SHOW COLUMNS once per table per sync,
// rather than refetching it on every aggregate.
@Cacheable
// function has to be open to make caching work
internal open fun getTableColumns(tableName: TableName) =
snowflakeClient.describeTable(tableName)
} }

View File

@@ -1,13 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.db
/**
* Jdbc destination column definition representation
*
* @param name
* @param type
*/
data class ColumnDefinition(val name: String, val type: String)

View File

@@ -4,17 +4,17 @@
package io.airbyte.integrations.destination.snowflake.db package io.airbyte.integrations.destination.snowflake.db
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.component.TableOperationsClient import io.airbyte.cdk.load.component.TableOperationsClient
import io.airbyte.cdk.load.orchestration.db.BaseDirectLoadInitialStatusGatherer import io.airbyte.cdk.load.table.BaseDirectLoadInitialStatusGatherer
import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
import jakarta.inject.Singleton import jakarta.inject.Singleton
@Singleton @Singleton
class SnowflakeDirectLoadDatabaseInitialStatusGatherer( class SnowflakeDirectLoadDatabaseInitialStatusGatherer(
tableOperationsClient: TableOperationsClient, tableOperationsClient: TableOperationsClient,
tempTableNameGenerator: TempTableNameGenerator, catalog: DestinationCatalog,
) : ) :
BaseDirectLoadInitialStatusGatherer( BaseDirectLoadInitialStatusGatherer(
tableOperationsClient, tableOperationsClient,
tempTableNameGenerator, catalog,
) )

View File

@@ -1,95 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.db
import io.airbyte.cdk.ConfigErrorException
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.orchestration.db.ColumnNameGenerator
import io.airbyte.cdk.load.orchestration.db.FinalTableNameGenerator
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TypingDedupingUtil
import io.airbyte.cdk.load.table.TableName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.QUOTE
import jakarta.inject.Singleton
@Singleton
class SnowflakeFinalTableNameGenerator(private val config: SnowflakeConfiguration) :
FinalTableNameGenerator {
override fun getTableName(streamDescriptor: DestinationStream.Descriptor): TableName {
val namespace = streamDescriptor.namespace ?: config.schema
return if (!config.legacyRawTablesOnly) {
TableName(
namespace = namespace.toSnowflakeCompatibleName(),
name = streamDescriptor.name.toSnowflakeCompatibleName(),
)
} else {
TableName(
namespace = config.internalTableSchema,
name =
TypingDedupingUtil.concatenateRawTableName(
namespace = escapeJsonIdentifier(namespace),
name = escapeJsonIdentifier(streamDescriptor.name),
),
)
}
}
}
@Singleton
class SnowflakeColumnNameGenerator(private val config: SnowflakeConfiguration) :
ColumnNameGenerator {
override fun getColumnName(column: String): ColumnNameGenerator.ColumnName {
return if (!config.legacyRawTablesOnly) {
ColumnNameGenerator.ColumnName(
column.toSnowflakeCompatibleName(),
column.toSnowflakeCompatibleName(),
)
} else {
ColumnNameGenerator.ColumnName(
column,
column,
)
}
}
}
/**
* Escapes double-quotes in a JSON identifier by doubling them. This shit is legacy -- I don't know
* why this would be necessary but no harm in keeping it so I am keeping it.
*
* @return The escaped identifier.
*/
fun escapeJsonIdentifier(identifier: String): String {
// Note that we don't need to escape backslashes here!
// The only special character in an identifier is the double-quote, which needs to be
// doubled.
return identifier.replace(QUOTE, "$QUOTE$QUOTE")
}
/**
* Transforms a string to be compatible with Snowflake table and column names.
*
* @return The transformed string suitable for Snowflake identifiers.
*/
fun String.toSnowflakeCompatibleName(): String {
var identifier = this
// Handle empty strings
if (identifier.isEmpty()) {
throw ConfigErrorException("Empty string is invalid identifier")
}
// Snowflake scripting language does something weird when the `${` bigram shows up in the
// script so replace these with something else.
// For completeness, if we trigger this, also replace closing curly braces with underscores.
if (identifier.contains("\${")) {
identifier = identifier.replace("$", "_").replace("{", "_").replace("}", "_")
}
// Escape double quotes
identifier = escapeJsonIdentifier(identifier)
return identifier.uppercase()
}

View File

@@ -0,0 +1,116 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.schema
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_EXTRACTED_AT
import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_GENERATION_ID
import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_META
import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_RAW_ID
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDataType
import jakarta.inject.Singleton
/**
* Manages column names and ordering for Snowflake tables based on whether legacy raw tables mode is
* enabled.
*
* TODO: We should add meta column munging and raw table support to the CDK, so this extra layer of
* management shouldn't be necessary.
*/
@Singleton
class SnowflakeColumnManager(
private val config: SnowflakeConfiguration,
) {
/**
* Get the list of column names for a table in the order they should appear in the CSV file and
* COPY INTO statement.
*
* Warning: MUST match the order defined in SnowflakeRecordFormatter
*
* @param columnSchema The schema containing column information (in raw mode this has already been collapsed to the single data column)
* @return List of column names in the correct order
*/
fun getTableColumnNames(columnSchema: ColumnSchema): List<String> {
return buildList {
addAll(getMetaColumnNames())
addAll(columnSchema.finalSchema.keys)
}
}
/**
* Get the list of Airbyte meta column names. In schema mode, these are uppercase. In raw mode,
* they are lowercase and include loaded_at.
*
* @return Set of meta column names
*/
fun getMetaColumnNames(): Set<String> =
if (config.legacyRawTablesOnly) {
Constants.rawModeMetaColNames
} else {
Constants.schemaModeMetaColNames
}
/**
* Get the Airbyte meta columns as a map of column name to ColumnType. This provides both the
* column names and their types for table creation.
*
* @return Map of meta column names to their types
*/
fun getMetaColumns(): LinkedHashMap<String, ColumnType> {
return if (config.legacyRawTablesOnly) {
Constants.rawModeMetaColumns
} else {
Constants.schemaModeMetaColumns
}
}
fun getGenerationIdColumnName(): String {
return if (config.legacyRawTablesOnly) {
Meta.COLUMN_NAME_AB_GENERATION_ID
} else {
SNOWFLAKE_AB_GENERATION_ID
}
}
object Constants {
val rawModeMetaColumns =
linkedMapOf(
Meta.COLUMN_NAME_AB_RAW_ID to ColumnType(SnowflakeDataType.VARCHAR.typeName, false),
Meta.COLUMN_NAME_AB_EXTRACTED_AT to
ColumnType(
SnowflakeDataType.TIMESTAMP_TZ.typeName,
false,
),
Meta.COLUMN_NAME_AB_META to ColumnType(SnowflakeDataType.VARIANT.typeName, false),
Meta.COLUMN_NAME_AB_GENERATION_ID to
ColumnType(
SnowflakeDataType.NUMBER.typeName,
true,
),
Meta.COLUMN_NAME_AB_LOADED_AT to
ColumnType(
SnowflakeDataType.TIMESTAMP_TZ.typeName,
true,
),
)
val schemaModeMetaColumns =
linkedMapOf(
SNOWFLAKE_AB_RAW_ID to ColumnType(SnowflakeDataType.VARCHAR.typeName, false),
SNOWFLAKE_AB_EXTRACTED_AT to
ColumnType(SnowflakeDataType.TIMESTAMP_TZ.typeName, false),
SNOWFLAKE_AB_META to ColumnType(SnowflakeDataType.VARIANT.typeName, false),
SNOWFLAKE_AB_GENERATION_ID to ColumnType(SnowflakeDataType.NUMBER.typeName, true),
)
val rawModeMetaColNames: Set<String> = rawModeMetaColumns.keys
val schemaModeMetaColNames: Set<String> = schemaModeMetaColumns.keys
}
}
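// --- Illustrative sketch (not part of this change): the ordering contract from the KDoc above,
// assuming an injected manager in schema (non-raw) mode and the usual "_airbyte_*" meta constants.
// The user column is hypothetical; ColumnSchema, ColumnType, FieldType and StringType are the CDK
// types used elsewhere in this PR (io.airbyte.cdk.load.*).
fun orderingExample(columnManager: SnowflakeColumnManager): List<String> {
    val schema =
        ColumnSchema(
            inputToFinalColumnNames = mapOf("name" to "NAME"),
            finalSchema = mapOf("NAME" to ColumnType("VARCHAR", true)),
            inputSchema = mapOf("name" to FieldType(StringType, nullable = true)),
        )
    // Expected order in schema mode: _AIRBYTE_RAW_ID, _AIRBYTE_EXTRACTED_AT, _AIRBYTE_META,
    // _AIRBYTE_GENERATION_ID, then the user column NAME.
    return columnManager.getTableColumnNames(schema)
}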

View File

@@ -0,0 +1,34 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.schema
import io.airbyte.cdk.ConfigErrorException
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
/**
* Transforms a string to be compatible with Snowflake table and column names.
*
* @return The transformed string suitable for Snowflake identifiers.
*/
fun String.toSnowflakeCompatibleName(): String {
var identifier = this
// Handle empty strings
if (identifier.isEmpty()) {
throw ConfigErrorException("Empty string is invalid identifier")
}
// Snowflake scripting misbehaves when the `${` bigram appears in a script, so replace those
// characters with underscores.
// For completeness, if we trigger this, also replace closing curly braces with underscores.
if (identifier.contains("\${")) {
identifier = identifier.replace("$", "_").replace("{", "_").replace("}", "_")
}
// Escape double quotes
identifier = escapeJsonIdentifier(identifier)
return identifier.uppercase()
}
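// --- Illustrative sketch (not part of this change): assumed expectations for the transformation
// above, derived from the code rather than from an existing test.
fun identifierExamples() {
    check("customers".toSnowflakeCompatibleName() == "CUSTOMERS")
    // Embedded double quotes are doubled by escapeJsonIdentifier, then the result is uppercased.
    check("""my"col""".toSnowflakeCompatibleName() == """MY""COL""")
    // The `${` bigram triggers replacement of $, { and } with underscores before uppercasing.
    check("my\${var}".toSnowflakeCompatibleName() == "MY__VAR_")
}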

View File

@@ -0,0 +1,121 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.schema
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.data.ArrayType
import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
import io.airbyte.cdk.load.data.BooleanType
import io.airbyte.cdk.load.data.DateType
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerType
import io.airbyte.cdk.load.data.NumberType
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.TimeTypeWithTimezone
import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
import io.airbyte.cdk.load.data.UnionType
import io.airbyte.cdk.load.data.UnknownType
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.schema.TableSchemaMapper
import io.airbyte.cdk.load.schema.model.StreamTableSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.TempTableNameGenerator
import io.airbyte.cdk.load.table.TypingDedupingUtil
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDataType
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
import jakarta.inject.Singleton
@Singleton
class SnowflakeTableSchemaMapper(
private val config: SnowflakeConfiguration,
private val tempTableNameGenerator: TempTableNameGenerator,
) : TableSchemaMapper {
override fun toFinalTableName(desc: DestinationStream.Descriptor): TableName {
val namespace = desc.namespace ?: config.schema
return if (!config.legacyRawTablesOnly) {
TableName(
namespace = namespace.toSnowflakeCompatibleName(),
name = desc.name.toSnowflakeCompatibleName(),
)
} else {
TableName(
namespace = config.internalTableSchema,
name =
TypingDedupingUtil.concatenateRawTableName(
namespace = escapeJsonIdentifier(namespace),
name = escapeJsonIdentifier(desc.name),
),
)
}
}
override fun toTempTableName(tableName: TableName): TableName {
return tempTableNameGenerator.generate(tableName)
}
override fun toColumnName(name: String): String {
return if (!config.legacyRawTablesOnly) {
name.toSnowflakeCompatibleName()
} else {
// In legacy mode, column names are not transformed
name
}
}
override fun toColumnType(fieldType: FieldType): ColumnType {
val snowflakeType =
when (fieldType.type) {
// Simple types
BooleanType -> SnowflakeDataType.BOOLEAN.typeName
IntegerType -> SnowflakeDataType.NUMBER.typeName
NumberType -> SnowflakeDataType.FLOAT.typeName
StringType -> SnowflakeDataType.VARCHAR.typeName
// Temporal types
DateType -> SnowflakeDataType.DATE.typeName
TimeTypeWithTimezone -> SnowflakeDataType.VARCHAR.typeName
TimeTypeWithoutTimezone -> SnowflakeDataType.TIME.typeName
TimestampTypeWithTimezone -> SnowflakeDataType.TIMESTAMP_TZ.typeName
TimestampTypeWithoutTimezone -> SnowflakeDataType.TIMESTAMP_NTZ.typeName
// Semistructured types
is ArrayType,
ArrayTypeWithoutSchema -> SnowflakeDataType.ARRAY.typeName
is ObjectType,
ObjectTypeWithEmptySchema,
ObjectTypeWithoutSchema -> SnowflakeDataType.OBJECT.typeName
is UnionType -> SnowflakeDataType.VARIANT.typeName
is UnknownType -> SnowflakeDataType.VARIANT.typeName
}
return ColumnType(snowflakeType, fieldType.nullable)
}
override fun toFinalSchema(tableSchema: StreamTableSchema): StreamTableSchema {
if (!config.legacyRawTablesOnly) {
return tableSchema
}
return StreamTableSchema(
tableNames = tableSchema.tableNames,
columnSchema =
tableSchema.columnSchema.copy(
finalSchema =
mapOf(
Meta.COLUMN_NAME_DATA to
ColumnType(SnowflakeDataType.OBJECT.typeName, false)
)
),
importType = tableSchema.importType,
)
}
}
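// --- Illustrative sketch (not part of this change): how the mapper resolves names and types in
// schema (non-raw) mode. Assumes an injected mapper with legacyRawTablesOnly = false, that
// DestinationStream.Descriptor can be constructed from a namespace and name as shown, and that
// FieldType/StringType come from the CDK imports already used above. The stream names are made up.
fun mapperExample(mapper: SnowflakeTableSchemaMapper) {
    val finalName =
        mapper.toFinalTableName(DestinationStream.Descriptor(namespace = "sales", name = "orders"))
    // finalName is expected to be TableName(namespace = "SALES", name = "ORDERS")
    val columnType = mapper.toColumnType(FieldType(StringType, nullable = true))
    // columnType is expected to be ColumnType("VARCHAR", true)
    println("$finalName -> $columnType")
}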

View File

@@ -1,201 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.sql
import com.google.common.annotations.VisibleForTesting
import io.airbyte.cdk.load.data.AirbyteType
import io.airbyte.cdk.load.data.ArrayType
import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
import io.airbyte.cdk.load.data.BooleanType
import io.airbyte.cdk.load.data.DateType
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerType
import io.airbyte.cdk.load.data.NumberType
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.TimeTypeWithTimezone
import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
import io.airbyte.cdk.load.data.UnionType
import io.airbyte.cdk.load.data.UnknownType
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.integrations.destination.snowflake.db.SnowflakeColumnNameGenerator
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import jakarta.inject.Singleton
import kotlin.collections.component1
import kotlin.collections.component2
import kotlin.collections.joinToString
import kotlin.collections.map
import kotlin.collections.plus
internal const val NOT_NULL = "NOT NULL"
internal val DEFAULT_COLUMNS =
listOf(
ColumnAndType(
columnName = COLUMN_NAME_AB_RAW_ID,
columnType = "${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL"
),
ColumnAndType(
columnName = COLUMN_NAME_AB_EXTRACTED_AT,
columnType = "${SnowflakeDataType.TIMESTAMP_TZ.typeName} $NOT_NULL"
),
ColumnAndType(
columnName = COLUMN_NAME_AB_META,
columnType = "${SnowflakeDataType.VARIANT.typeName} $NOT_NULL"
),
ColumnAndType(
columnName = COLUMN_NAME_AB_GENERATION_ID,
columnType = SnowflakeDataType.NUMBER.typeName
),
)
internal val RAW_DATA_COLUMN =
ColumnAndType(
columnName = COLUMN_NAME_DATA,
columnType = "${SnowflakeDataType.VARIANT.typeName} $NOT_NULL"
)
internal val RAW_COLUMNS =
listOf(
ColumnAndType(
columnName = COLUMN_NAME_AB_LOADED_AT,
columnType = SnowflakeDataType.TIMESTAMP_TZ.typeName
),
RAW_DATA_COLUMN
)
@Singleton
class SnowflakeColumnUtils(
private val snowflakeConfiguration: SnowflakeConfiguration,
private val snowflakeColumnNameGenerator: SnowflakeColumnNameGenerator,
) {
@VisibleForTesting
internal fun defaultColumns(): List<ColumnAndType> =
if (snowflakeConfiguration.legacyRawTablesOnly) {
DEFAULT_COLUMNS + RAW_COLUMNS
} else {
DEFAULT_COLUMNS
}
internal fun formattedDefaultColumns(): List<ColumnAndType> =
defaultColumns().map {
ColumnAndType(
columnName = formatColumnName(it.columnName, false),
columnType = it.columnType,
)
}
fun getGenerationIdColumnName(): String {
return if (snowflakeConfiguration.legacyRawTablesOnly) {
COLUMN_NAME_AB_GENERATION_ID
} else {
COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()
}
}
fun getColumnNames(columnNameMapping: ColumnNameMapping): String =
if (snowflakeConfiguration.legacyRawTablesOnly) {
getFormattedDefaultColumnNames(true).joinToString(",")
} else {
(getFormattedDefaultColumnNames(true) +
columnNameMapping.map { (_, actualName) -> actualName.quote() })
.joinToString(",")
}
fun getFormattedDefaultColumnNames(quote: Boolean = false): List<String> =
defaultColumns().map { formatColumnName(it.columnName, quote) }
fun getFormattedColumnNames(
columns: Map<String, FieldType>,
columnNameMapping: ColumnNameMapping,
quote: Boolean = true,
): List<String> =
if (snowflakeConfiguration.legacyRawTablesOnly) {
getFormattedDefaultColumnNames(quote)
} else {
getFormattedDefaultColumnNames(quote) +
columns.map { (fieldName, _) ->
val columnName = columnNameMapping[fieldName] ?: fieldName
if (quote) columnName.quote() else columnName
}
}
fun columnsAndTypes(
columns: Map<String, FieldType>,
columnNameMapping: ColumnNameMapping
): List<ColumnAndType> =
if (snowflakeConfiguration.legacyRawTablesOnly) {
formattedDefaultColumns()
} else {
formattedDefaultColumns() +
columns.map { (fieldName, type) ->
val columnName = columnNameMapping[fieldName] ?: fieldName
val typeName = toDialectType(type.type)
ColumnAndType(
columnName = columnName,
columnType = if (type.nullable) typeName else "$typeName $NOT_NULL",
)
}
}
fun formatColumnName(
columnName: String,
quote: Boolean = true,
): String {
val formattedColumnName =
if (columnName == COLUMN_NAME_DATA) columnName
else snowflakeColumnNameGenerator.getColumnName(columnName).displayName
return if (quote) formattedColumnName.quote() else formattedColumnName
}
fun toDialectType(type: AirbyteType): String =
when (type) {
// Simple types
BooleanType -> SnowflakeDataType.BOOLEAN.typeName
IntegerType -> SnowflakeDataType.NUMBER.typeName
NumberType -> SnowflakeDataType.FLOAT.typeName
StringType -> SnowflakeDataType.VARCHAR.typeName
// Temporal types
DateType -> SnowflakeDataType.DATE.typeName
TimeTypeWithTimezone -> SnowflakeDataType.VARCHAR.typeName
TimeTypeWithoutTimezone -> SnowflakeDataType.TIME.typeName
TimestampTypeWithTimezone -> SnowflakeDataType.TIMESTAMP_TZ.typeName
TimestampTypeWithoutTimezone -> SnowflakeDataType.TIMESTAMP_NTZ.typeName
// Semistructured types
is ArrayType,
ArrayTypeWithoutSchema -> SnowflakeDataType.ARRAY.typeName
is ObjectType,
ObjectTypeWithEmptySchema,
ObjectTypeWithoutSchema -> SnowflakeDataType.OBJECT.typeName
is UnionType -> SnowflakeDataType.VARIANT.typeName
is UnknownType -> SnowflakeDataType.VARIANT.typeName
}
}
data class ColumnAndType(val columnName: String, val columnType: String) {
override fun toString(): String {
return "${columnName.quote()} $columnType"
}
}
/**
* Surrounds the string instance with double quotation marks (e.g. "some string" -> "\"some
* string\"").
*/
fun String.quote() = "$QUOTE$this$QUOTE"

View File

@@ -10,7 +10,7 @@ package io.airbyte.integrations.destination.snowflake.sql
*/ */
enum class SnowflakeDataType(val typeName: String) { enum class SnowflakeDataType(val typeName: String) {
// Numeric types // Numeric types
NUMBER("NUMBER(38,0)"), NUMBER("NUMBER"),
FLOAT("FLOAT"), FLOAT("FLOAT"),
// String & binary types // String & binary types

View File

@@ -4,16 +4,19 @@
package io.airbyte.integrations.destination.snowflake.sql package io.airbyte.integrations.destination.snowflake.sql
import io.airbyte.cdk.load.command.Dedupe import com.google.common.annotations.VisibleForTesting
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.component.ColumnType import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.component.ColumnTypeChange import io.airbyte.cdk.load.component.ColumnTypeChange
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.cdk.load.schema.model.StreamTableSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.util.UUIDGenerator import io.airbyte.cdk.load.util.UUIDGenerator
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.CdcDeletionMode import io.airbyte.integrations.destination.snowflake.spec.CdcDeletionMode
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.write.load.CSV_FIELD_SEPARATOR import io.airbyte.integrations.destination.snowflake.write.load.CSV_FIELD_SEPARATOR
@@ -22,6 +25,14 @@ import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton import jakarta.inject.Singleton
internal const val COUNT_TOTAL_ALIAS = "TOTAL" internal const val COUNT_TOTAL_ALIAS = "TOTAL"
internal const val NOT_NULL = "NOT NULL"
// Snowflake-compatible (uppercase) versions of the Airbyte meta column names
internal val SNOWFLAKE_AB_RAW_ID = COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_EXTRACTED_AT = COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_META = COLUMN_NAME_AB_META.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_GENERATION_ID = COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN = CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()
private val log = KotlinLogging.logger {} private val log = KotlinLogging.logger {}
@@ -36,80 +47,91 @@ fun String.andLog(): String {
@Singleton @Singleton
class SnowflakeDirectLoadSqlGenerator( class SnowflakeDirectLoadSqlGenerator(
private val columnUtils: SnowflakeColumnUtils,
private val uuidGenerator: UUIDGenerator, private val uuidGenerator: UUIDGenerator,
private val snowflakeConfiguration: SnowflakeConfiguration, private val config: SnowflakeConfiguration,
private val snowflakeSqlNameUtils: SnowflakeSqlNameUtils, private val columnManager: SnowflakeColumnManager,
) { ) {
fun countTable(tableName: TableName): String { fun countTable(tableName: TableName): String {
return "SELECT COUNT(*) AS $COUNT_TOTAL_ALIAS FROM ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog() return "SELECT COUNT(*) AS $COUNT_TOTAL_ALIAS FROM ${fullyQualifiedName(tableName)}".andLog()
} }
fun createNamespace(namespace: String): String { fun createNamespace(namespace: String): String {
return "CREATE SCHEMA IF NOT EXISTS ${snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)}".andLog() return "CREATE SCHEMA IF NOT EXISTS ${fullyQualifiedNamespace(namespace)}".andLog()
} }
fun createTable( fun createTable(
stream: DestinationStream,
tableName: TableName, tableName: TableName,
columnNameMapping: ColumnNameMapping, tableSchema: StreamTableSchema,
replace: Boolean replace: Boolean
): String { ): String {
val finalSchema = tableSchema.columnSchema.finalSchema
val metaColumns = columnManager.getMetaColumns()
// Build column declarations from the meta columns and user schema
val columnDeclarations = val columnDeclarations =
columnUtils buildList {
.columnsAndTypes(stream.schema.asColumns(), columnNameMapping) // Add Airbyte meta columns from the column manager
.joinToString(",\n") metaColumns.forEach { (columnName, columnType) ->
val nullability = if (columnType.nullable) "" else " NOT NULL"
add("${columnName.quote()} ${columnType.type}$nullability")
}
// Add user columns from the munged schema
finalSchema.forEach { (columnName, columnType) ->
val nullability = if (columnType.nullable) "" else " NOT NULL"
add("${columnName.quote()} ${columnType.type}$nullability")
}
}
.joinToString(",\n ")
// Snowflake supports CREATE OR REPLACE TABLE, which is simpler than drop+recreate // Snowflake supports CREATE OR REPLACE TABLE, which is simpler than drop+recreate
val createOrReplace = if (replace) "CREATE OR REPLACE" else "CREATE" val createOrReplace = if (replace) "CREATE OR REPLACE" else "CREATE"
val createTableStatement = val createTableStatement =
""" """
$createOrReplace TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)} ( |$createOrReplace TABLE ${fullyQualifiedName(tableName)} (
$columnDeclarations | $columnDeclarations
) |)
""".trimIndent() """.trimMargin() // Something was tripping up trimIndent so we opt for trimMargin
return createTableStatement.andLog() return createTableStatement.andLog()
} }
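Illustrative sketch (not part of this change): a minimal standalone Kotlin example of how the meta and user column declarations above assemble into the generated CREATE TABLE statement. The table and column names used here are made-up examples, not taken from the PR.

// Hypothetical sketch of the CREATE TABLE assembly shown above.
fun buildCreateTable(fqName: String, columns: List<Triple<String, String, Boolean>>, replace: Boolean): String {
    val decls = columns.joinToString(",\n  ") { (name, type, nullable) ->
        val nullability = if (nullable) "" else " NOT NULL"
        "\"$name\" $type$nullability"
    }
    val verb = if (replace) "CREATE OR REPLACE" else "CREATE"
    return """
        |$verb TABLE $fqName (
        |  $decls
        |)
    """.trimMargin()
}

fun main() {
    // "AIRBYTE_DB"."PUBLIC"."USERS" and its columns are example inputs only.
    println(
        buildCreateTable(
            "\"AIRBYTE_DB\".\"PUBLIC\".\"USERS\"",
            listOf(Triple("_AIRBYTE_RAW_ID", "VARCHAR", false), Triple("ID", "NUMBER", true)),
            replace = true,
        )
    )
}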
fun showColumns(tableName: TableName): String = fun showColumns(tableName: TableName): String =
"SHOW COLUMNS IN TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog() "SHOW COLUMNS IN TABLE ${fullyQualifiedName(tableName)}".andLog()
fun copyTable( fun copyTable(
columnNameMapping: ColumnNameMapping, columnNames: Set<String>,
sourceTableName: TableName, sourceTableName: TableName,
targetTableName: TableName targetTableName: TableName
): String { ): String {
val columnNames = columnUtils.getColumnNames(columnNameMapping) val columnList = columnNames.joinToString(", ") { it.quote() }
return """ return """
INSERT INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} INSERT INTO ${fullyQualifiedName(targetTableName)}
( (
$columnNames $columnList
) )
SELECT SELECT
$columnNames $columnList
FROM ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} FROM ${fullyQualifiedName(sourceTableName)}
""" """
.trimIndent() .trimIndent()
.andLog() .andLog()
} }
fun upsertTable( fun upsertTable(
stream: DestinationStream, tableSchema: StreamTableSchema,
columnNameMapping: ColumnNameMapping,
sourceTableName: TableName, sourceTableName: TableName,
targetTableName: TableName targetTableName: TableName
): String { ): String {
val importType = stream.importType as Dedupe val finalSchema = tableSchema.columnSchema.finalSchema
// Build primary key matching condition // Build primary key matching condition
val pks = tableSchema.getPrimaryKey().flatten()
val pkEquivalent = val pkEquivalent =
if (importType.primaryKey.isNotEmpty()) { if (pks.isNotEmpty()) {
importType.primaryKey.joinToString(" AND ") { fieldPath -> pks.joinToString(" AND ") { columnName ->
val fieldName = fieldPath.first()
val columnName = columnNameMapping[fieldName] ?: fieldName
val targetTableColumnName = "target_table.${columnName.quote()}" val targetTableColumnName = "target_table.${columnName.quote()}"
val newRecordColumnName = "new_record.${columnName.quote()}" val newRecordColumnName = "new_record.${columnName.quote()}"
"""($targetTableColumnName = $newRecordColumnName OR ($targetTableColumnName IS NULL AND $newRecordColumnName IS NULL))""" """($targetTableColumnName = $newRecordColumnName OR ($targetTableColumnName IS NULL AND $newRecordColumnName IS NULL))"""
@@ -120,80 +142,62 @@ class SnowflakeDirectLoadSqlGenerator(
} }
// Build column lists for INSERT and UPDATE // Build column lists for INSERT and UPDATE
val columnList: String = val allColumns = buildList {
columnUtils add(SNOWFLAKE_AB_RAW_ID)
.getFormattedColumnNames( add(SNOWFLAKE_AB_EXTRACTED_AT)
columns = stream.schema.asColumns(), add(SNOWFLAKE_AB_META)
columnNameMapping = columnNameMapping, add(SNOWFLAKE_AB_GENERATION_ID)
quote = false, addAll(finalSchema.keys)
) }
.joinToString(
",\n",
) {
it.quote()
}
val columnList: String = allColumns.joinToString(",\n ") { it.quote() }
val newRecordColumnList: String = val newRecordColumnList: String =
columnUtils allColumns.joinToString(",\n ") { "new_record.${it.quote()}" }
.getFormattedColumnNames(
columns = stream.schema.asColumns(),
columnNameMapping = columnNameMapping,
quote = false,
)
.joinToString(",\n") { "new_record.${it.quote()}" }
// Get deduped records from source // Get deduped records from source
val selectSourceRecords = selectDedupedRecords(stream, sourceTableName, columnNameMapping) val selectSourceRecords = selectDedupedRecords(tableSchema, sourceTableName)
// Build cursor comparison for determining which record is newer // Build cursor comparison for determining which record is newer
val cursorComparison: String val cursorComparison: String
if (importType.cursor.isNotEmpty()) { val cursor = tableSchema.getCursor().firstOrNull()
val cursorFieldName = importType.cursor.first() if (cursor != null) {
val cursor = (columnNameMapping[cursorFieldName] ?: cursorFieldName)
val targetTableCursor = "target_table.${cursor.quote()}" val targetTableCursor = "target_table.${cursor.quote()}"
val newRecordCursor = "new_record.${cursor.quote()}" val newRecordCursor = "new_record.${cursor.quote()}"
cursorComparison = cursorComparison =
""" """
( (
$targetTableCursor < $newRecordCursor $targetTableCursor < $newRecordCursor
OR ($targetTableCursor = $newRecordCursor AND target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}") OR ($targetTableCursor = $newRecordCursor AND target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT")
OR ($targetTableCursor IS NULL AND $newRecordCursor IS NULL AND target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}") OR ($targetTableCursor IS NULL AND $newRecordCursor IS NULL AND target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT")
OR ($targetTableCursor IS NULL AND $newRecordCursor IS $NOT_NULL) OR ($targetTableCursor IS NULL AND $newRecordCursor IS $NOT_NULL)
) )
""".trimIndent() """.trimIndent()
} else { } else {
// No cursor - use extraction timestamp only // No cursor - use extraction timestamp only
cursorComparison = cursorComparison =
"""target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}"""" """target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT""""
} }
// Build column assignments for UPDATE // Build column assignments for UPDATE
val columnAssignments: String = val columnAssignments: String =
columnUtils allColumns.joinToString(",\n ") { column ->
.getFormattedColumnNames( "${column.quote()} = new_record.${column.quote()}"
columns = stream.schema.asColumns(), }
columnNameMapping = columnNameMapping,
quote = false,
)
.joinToString(",\n") { column ->
"${column.quote()} = new_record.${column.quote()}"
}
// Handle CDC deletions based on mode // Handle CDC deletions based on mode
val cdcDeleteClause: String val cdcDeleteClause: String
val cdcSkipInsertClause: String val cdcSkipInsertClause: String
if ( if (
stream.schema.asColumns().containsKey(CDC_DELETED_AT_COLUMN) && finalSchema.containsKey(SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN) &&
snowflakeConfiguration.cdcDeletionMode == CdcDeletionMode.HARD_DELETE config.cdcDeletionMode == CdcDeletionMode.HARD_DELETE
) { ) {
// Execute CDC deletions if there's already a record // Execute CDC deletions if there's already a record
cdcDeleteClause = cdcDeleteClause =
"WHEN MATCHED AND new_record.\"${CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()}\" IS NOT NULL AND $cursorComparison THEN DELETE" "WHEN MATCHED AND new_record.\"${SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN}\" IS NOT NULL AND $cursorComparison THEN DELETE"
// And skip insertion entirely if there's no matching record. // And skip insertion entirely if there's no matching record.
// (This is possible if a single T+D batch contains both an insertion and deletion for // (This is possible if a single T+D batch contains both an insertion and deletion for
// the same PK) // the same PK)
cdcSkipInsertClause = cdcSkipInsertClause = "AND new_record.\"${SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN}\" IS NULL"
"AND new_record.\"${CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()}\" IS NULL"
} else { } else {
cdcDeleteClause = "" cdcDeleteClause = ""
cdcSkipInsertClause = "" cdcSkipInsertClause = ""
@@ -203,35 +207,35 @@ class SnowflakeDirectLoadSqlGenerator(
val mergeStatement = val mergeStatement =
if (cdcDeleteClause.isNotEmpty()) { if (cdcDeleteClause.isNotEmpty()) {
""" """
MERGE INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} AS target_table |MERGE INTO ${fullyQualifiedName(targetTableName)} AS target_table
USING ( |USING (
$selectSourceRecords |$selectSourceRecords
) AS new_record |) AS new_record
ON $pkEquivalent |ON $pkEquivalent
$cdcDeleteClause |$cdcDeleteClause
WHEN MATCHED AND $cursorComparison THEN UPDATE SET |WHEN MATCHED AND $cursorComparison THEN UPDATE SET
$columnAssignments | $columnAssignments
WHEN NOT MATCHED $cdcSkipInsertClause THEN INSERT ( |WHEN NOT MATCHED $cdcSkipInsertClause THEN INSERT (
$columnList | $columnList
) VALUES ( |) VALUES (
$newRecordColumnList | $newRecordColumnList
) |)
""".trimIndent() """.trimMargin()
} else { } else {
""" """
MERGE INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} AS target_table |MERGE INTO ${fullyQualifiedName(targetTableName)} AS target_table
USING ( |USING (
$selectSourceRecords |$selectSourceRecords
) AS new_record |) AS new_record
ON $pkEquivalent |ON $pkEquivalent
WHEN MATCHED AND $cursorComparison THEN UPDATE SET |WHEN MATCHED AND $cursorComparison THEN UPDATE SET
$columnAssignments | $columnAssignments
WHEN NOT MATCHED THEN INSERT ( |WHEN NOT MATCHED THEN INSERT (
$columnList | $columnList
) VALUES ( |) VALUES (
$newRecordColumnList | $newRecordColumnList
) |)
""".trimIndent() """.trimMargin()
} }
return mergeStatement.andLog() return mergeStatement.andLog()
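Illustrative sketch (not part of this change): the null-safe primary-key match that the MERGE ON clause above builds, extracted into a standalone function. The primary-key column names in the example are assumptions.

// Hypothetical sketch of the null-safe equality joining target_table and new_record.
fun pkMatchClause(pks: List<String>): String =
    pks.joinToString(" AND ") { col ->
        val target = "target_table.\"$col\""
        val incoming = "new_record.\"$col\""
        "($target = $incoming OR ($target IS NULL AND $incoming IS NULL))"
    }

fun main() {
    // "ID" and "TENANT_ID" are made-up example columns.
    println(pkMatchClause(listOf("ID", "TENANT_ID")))
}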
@@ -242,75 +246,66 @@ class SnowflakeDirectLoadSqlGenerator(
* table. Uses ROW_NUMBER() window function to select the most recent record per primary key. * table. Uses ROW_NUMBER() window function to select the most recent record per primary key.
*/ */
private fun selectDedupedRecords( private fun selectDedupedRecords(
stream: DestinationStream, tableSchema: StreamTableSchema,
sourceTableName: TableName, sourceTableName: TableName
columnNameMapping: ColumnNameMapping
): String { ): String {
val columnList: String = val allColumns = buildList {
columnUtils add(SNOWFLAKE_AB_RAW_ID)
.getFormattedColumnNames( add(SNOWFLAKE_AB_EXTRACTED_AT)
columns = stream.schema.asColumns(), add(SNOWFLAKE_AB_META)
columnNameMapping = columnNameMapping, add(SNOWFLAKE_AB_GENERATION_ID)
quote = false, addAll(tableSchema.columnSchema.finalSchema.keys)
) }
.joinToString( val columnList: String = allColumns.joinToString(",\n ") { it.quote() }
",\n",
) {
it.quote()
}
val importType = stream.importType as Dedupe
// Build the primary key list for partitioning // Build the primary key list for partitioning
val pks = tableSchema.getPrimaryKey().flatten()
val pkList = val pkList =
if (importType.primaryKey.isNotEmpty()) { if (pks.isNotEmpty()) {
importType.primaryKey.joinToString(",") { fieldPath -> pks.joinToString(",") { it.quote() }
(columnNameMapping[fieldPath.first()] ?: fieldPath.first()).quote()
}
} else { } else {
// Should not happen as we check this earlier, but handle it defensively // Should not happen as we check this earlier, but handle it defensively
throw IllegalArgumentException("Cannot deduplicate without primary key") throw IllegalArgumentException("Cannot deduplicate without primary key")
} }
// Build cursor order clause for sorting within each partition // Build cursor order clause for sorting within each partition
val cursor = tableSchema.getCursor().firstOrNull()
val cursorOrderClause = val cursorOrderClause =
if (importType.cursor.isNotEmpty()) { if (cursor != null) {
val columnName = "${cursor.quote()} DESC NULLS LAST,"
(columnNameMapping[importType.cursor.first()] ?: importType.cursor.first())
.quote()
"$columnName DESC NULLS LAST,"
} else { } else {
"" ""
} }
return """ return """
WITH records AS ( | WITH records AS (
SELECT | SELECT
$columnList | $columnList
FROM ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} | FROM ${fullyQualifiedName(sourceTableName)}
), numbered_rows AS ( | ), numbered_rows AS (
SELECT *, ROW_NUMBER() OVER ( | SELECT *, ROW_NUMBER() OVER (
PARTITION BY $pkList ORDER BY $cursorOrderClause "${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" DESC | PARTITION BY $pkList ORDER BY $cursorOrderClause "$SNOWFLAKE_AB_EXTRACTED_AT" DESC
) AS row_number | ) AS row_number
FROM records | FROM records
) | )
SELECT $columnList | SELECT $columnList
FROM numbered_rows | FROM numbered_rows
WHERE row_number = 1 | WHERE row_number = 1
""" """
.trimIndent() .trimMargin()
.andLog() .andLog()
} }
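Illustrative sketch (not part of this change) of the ROW_NUMBER() deduplication pattern generated above; the table, column, and "_AIRBYTE_EXTRACTED_AT" identifiers here are assumptions for the example.

// Hypothetical standalone version of the dedup query: keep the newest row per primary key.
fun dedupQuery(table: String, columns: List<String>, pk: String, cursor: String?): String {
    val cols = columns.joinToString(", ") { "\"$it\"" }
    val order =
        if (cursor != null) "\"$cursor\" DESC NULLS LAST, \"_AIRBYTE_EXTRACTED_AT\" DESC"
        else "\"_AIRBYTE_EXTRACTED_AT\" DESC"
    return """
        |WITH records AS (SELECT $cols FROM $table),
        |numbered_rows AS (
        |  SELECT *, ROW_NUMBER() OVER (PARTITION BY "$pk" ORDER BY $order) AS row_number
        |  FROM records
        |)
        |SELECT $cols FROM numbered_rows WHERE row_number = 1
    """.trimMargin()
}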
fun dropTable(tableName: TableName): String { fun dropTable(tableName: TableName): String {
return "DROP TABLE IF EXISTS ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog() return "DROP TABLE IF EXISTS ${fullyQualifiedName(tableName)}".andLog()
} }
fun getGenerationId( fun getGenerationId(
tableName: TableName, tableName: TableName,
): String { ): String {
return """ return """
SELECT "${columnUtils.getGenerationIdColumnName()}" SELECT "${columnManager.getGenerationIdColumnName()}"
FROM ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)} FROM ${fullyQualifiedName(tableName)}
LIMIT 1 LIMIT 1
""" """
.trimIndent() .trimIndent()
@@ -318,12 +313,12 @@ class SnowflakeDirectLoadSqlGenerator(
} }
fun createSnowflakeStage(tableName: TableName): String { fun createSnowflakeStage(tableName: TableName): String {
val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName) val stageName = fullyQualifiedStageName(tableName)
return "CREATE STAGE IF NOT EXISTS $stageName".andLog() return "CREATE STAGE IF NOT EXISTS $stageName".andLog()
} }
fun putInStage(tableName: TableName, tempFilePath: String): String { fun putInStage(tableName: TableName, tempFilePath: String): String {
val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true) val stageName = fullyQualifiedStageName(tableName, true)
return """ return """
PUT 'file://$tempFilePath' '@$stageName' PUT 'file://$tempFilePath' '@$stageName'
AUTO_COMPRESS = FALSE AUTO_COMPRESS = FALSE
@@ -334,35 +329,45 @@ class SnowflakeDirectLoadSqlGenerator(
.andLog() .andLog()
} }
fun copyFromStage(tableName: TableName, filename: String): String { fun copyFromStage(
val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true) tableName: TableName,
filename: String,
columnNames: List<String>? = null
): String {
val stageName = fullyQualifiedStageName(tableName, true)
val columnList =
columnNames?.let { names -> "(${names.joinToString(", ") { it.quote() }})" } ?: ""
return """ return """
COPY INTO ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)} |COPY INTO ${fullyQualifiedName(tableName)}$columnList
FROM '@$stageName' |FROM '@$stageName'
FILE_FORMAT = ( |FILE_FORMAT = (
TYPE = 'CSV' | TYPE = 'CSV'
COMPRESSION = GZIP | COMPRESSION = GZIP
FIELD_DELIMITER = '$CSV_FIELD_SEPARATOR' | FIELD_DELIMITER = '$CSV_FIELD_SEPARATOR'
RECORD_DELIMITER = '$CSV_LINE_DELIMITER' | RECORD_DELIMITER = '$CSV_LINE_DELIMITER'
FIELD_OPTIONALLY_ENCLOSED_BY = '"' | FIELD_OPTIONALLY_ENCLOSED_BY = '"'
TRIM_SPACE = TRUE | TRIM_SPACE = TRUE
ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE | ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE
REPLACE_INVALID_CHARACTERS = TRUE | REPLACE_INVALID_CHARACTERS = TRUE
ESCAPE = NONE | ESCAPE = NONE
ESCAPE_UNENCLOSED_FIELD = NONE | ESCAPE_UNENCLOSED_FIELD = NONE
) |)
ON_ERROR = 'ABORT_STATEMENT' |ON_ERROR = 'ABORT_STATEMENT'
PURGE = TRUE |PURGE = TRUE
files = ('$filename') |files = ('$filename')
""" """
.trimIndent() .trimMargin()
.andLog() .andLog()
} }
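Illustrative sketch (not part of this change): how the new optional columnNames parameter is rendered into the COPY INTO header. The identifiers below are made-up examples.

// Hypothetical sketch of the optional column list added to COPY INTO.
fun copyIntoHeader(fqTable: String, stage: String, columnNames: List<String>?): String {
    val columnList =
        columnNames?.joinToString(", ", prefix = "(", postfix = ")") { "\"$it\"" } ?: ""
    return "COPY INTO $fqTable$columnList\nFROM '@$stage'"
}

// copyIntoHeader("\"DB\".\"S\".\"USERS\"", "DB.S.airbyte_stage_USERS", listOf("ID", "NAME"))
// produces: COPY INTO "DB"."S"."USERS"("ID", "NAME") followed by FROM '@DB.S.airbyte_stage_USERS'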
fun swapTableWith(sourceTableName: TableName, targetTableName: TableName): String { fun swapTableWith(sourceTableName: TableName, targetTableName: TableName): String {
return """ return """
ALTER TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} SWAP WITH ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} ALTER TABLE ${fullyQualifiedName(sourceTableName)} SWAP WITH ${
fullyQualifiedName(
targetTableName,
)
}
""" """
.trimIndent() .trimIndent()
.andLog() .andLog()
@@ -372,7 +377,11 @@ class SnowflakeDirectLoadSqlGenerator(
// Snowflake RENAME TO only accepts the table name, not a fully qualified name // Snowflake RENAME TO only accepts the table name, not a fully qualified name
// The renamed table stays in the same schema // The renamed table stays in the same schema
return """ return """
ALTER TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} RENAME TO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} ALTER TABLE ${fullyQualifiedName(sourceTableName)} RENAME TO ${
fullyQualifiedName(
targetTableName,
)
}
""" """
.trimIndent() .trimIndent()
.andLog() .andLog()
@@ -382,7 +391,7 @@ class SnowflakeDirectLoadSqlGenerator(
schemaName: String, schemaName: String,
tableName: String, tableName: String,
): String = ): String =
"""DESCRIBE TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(TableName(schemaName, tableName))}""".andLog() """DESCRIBE TABLE ${fullyQualifiedName(TableName(schemaName, tableName))}""".andLog()
fun alterTable( fun alterTable(
tableName: TableName, tableName: TableName,
@@ -391,14 +400,14 @@ class SnowflakeDirectLoadSqlGenerator(
modifiedColumns: Map<String, ColumnTypeChange>, modifiedColumns: Map<String, ColumnTypeChange>,
): Set<String> { ): Set<String> {
val clauses = mutableSetOf<String>() val clauses = mutableSetOf<String>()
val prettyTableName = snowflakeSqlNameUtils.fullyQualifiedName(tableName) val prettyTableName = fullyQualifiedName(tableName)
addedColumns.forEach { (name, columnType) -> addedColumns.forEach { (name, columnType) ->
clauses.add( clauses.add(
// Note that we intentionally don't set NOT NULL. // Note that we intentionally don't set NOT NULL.
// We're adding a new column, and we don't know what constitutes a reasonable // We're adding a new column, and we don't know what constitutes a reasonable
// default value for preexisting records. // default value for preexisting records.
// So we add the column as nullable. // So we add the column as nullable.
"ALTER TABLE $prettyTableName ADD COLUMN ${name.quote()} ${columnType.type};".andLog() "ALTER TABLE $prettyTableName ADD COLUMN ${name.quote()} ${columnType.type};".andLog(),
) )
} }
deletedColumns.forEach { deletedColumns.forEach {
@@ -412,35 +421,34 @@ class SnowflakeDirectLoadSqlGenerator(
val tempColumn = "${name}_${uuidGenerator.v4()}" val tempColumn = "${name}_${uuidGenerator.v4()}"
clauses.add( clauses.add(
// As above: we add the column as nullable. // As above: we add the column as nullable.
"ALTER TABLE $prettyTableName ADD COLUMN ${tempColumn.quote()} ${typeChange.newType.type};".andLog() "ALTER TABLE $prettyTableName ADD COLUMN ${tempColumn.quote()} ${typeChange.newType.type};".andLog(),
) )
clauses.add( clauses.add(
"UPDATE $prettyTableName SET ${tempColumn.quote()} = CAST(${name.quote()} AS ${typeChange.newType.type});".andLog() "UPDATE $prettyTableName SET ${tempColumn.quote()} = CAST(${name.quote()} AS ${typeChange.newType.type});".andLog(),
) )
val backupColumn = "${tempColumn}_backup" val backupColumn = "${tempColumn}_backup"
clauses.add( clauses.add(
""" """
ALTER TABLE $prettyTableName ALTER TABLE $prettyTableName
RENAME COLUMN "$name" TO "$backupColumn"; RENAME COLUMN "$name" TO "$backupColumn";
""".trimIndent() """.trimIndent(),
) )
clauses.add( clauses.add(
""" """
ALTER TABLE $prettyTableName ALTER TABLE $prettyTableName
RENAME COLUMN "$tempColumn" TO "$name"; RENAME COLUMN "$tempColumn" TO "$name";
""".trimIndent() """.trimIndent(),
) )
clauses.add( clauses.add(
"ALTER TABLE $prettyTableName DROP COLUMN ${backupColumn.quote()};".andLog() "ALTER TABLE $prettyTableName DROP COLUMN ${backupColumn.quote()};".andLog(),
) )
} else if (!typeChange.originalType.nullable && typeChange.newType.nullable) { } else if (!typeChange.originalType.nullable && typeChange.newType.nullable) {
// If the type is unchanged, we can change a column from NOT NULL to nullable. // If the type is unchanged, we can change a column from NOT NULL to nullable.
// But we'll never do the reverse, because there's a decent chance that historical // But we'll never do the reverse, because there's a decent chance that historical
// records // records had null values.
// had null values.
// Users can always manually ALTER COLUMN ... SET NOT NULL if they want. // Users can always manually ALTER COLUMN ... SET NOT NULL if they want.
clauses.add( clauses.add(
"""ALTER TABLE $prettyTableName ALTER COLUMN "$name" DROP NOT NULL;""".andLog() """ALTER TABLE $prettyTableName ALTER COLUMN "$name" DROP NOT NULL;""".andLog(),
) )
} else { } else {
log.info { log.info {
@@ -450,4 +458,45 @@ class SnowflakeDirectLoadSqlGenerator(
} }
return clauses return clauses
} }
@VisibleForTesting
fun fullyQualifiedName(tableName: TableName): String =
combineParts(listOf(getDatabaseName(), tableName.namespace, tableName.name))
@VisibleForTesting
fun fullyQualifiedNamespace(namespace: String) =
combineParts(listOf(getDatabaseName(), namespace))
@VisibleForTesting
fun fullyQualifiedStageName(tableName: TableName, escape: Boolean = false): String {
val currentTableName =
if (escape) {
tableName.name
} else {
tableName.name
}
return combineParts(
parts =
listOf(
getDatabaseName(),
tableName.namespace,
"$STAGE_NAME_PREFIX$currentTableName",
),
escape = escape,
)
}
@VisibleForTesting
internal fun combineParts(parts: List<String>, escape: Boolean = false): String =
parts
.map { if (escape) sqlEscape(it) else it }
.joinToString(separator = ".") {
if (!it.startsWith(QUOTE)) {
"$QUOTE$it$QUOTE"
} else {
it
}
}
private fun getDatabaseName() = config.database.toSnowflakeCompatibleName()
} }
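Illustrative usage (not part of this change) of the name helpers moved into the generator above: each part is wrapped in double quotes unless it is already quoted. The database, schema, and table names are made-up examples.

// Hypothetical standalone version of the part-quoting behaviour.
fun quoteParts(parts: List<String>): String =
    parts.joinToString(".") { if (it.startsWith("\"")) it else "\"$it\"" }

fun main() {
    println(quoteParts(listOf("AIRBYTE_DB", "PUBLIC", "USERS")))               // "AIRBYTE_DB"."PUBLIC"."USERS"
    println(quoteParts(listOf("AIRBYTE_DB", "PUBLIC", "airbyte_stage_USERS"))) // stage-style name for PUT/COPY
}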
@@ -0,0 +1,29 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.sql
const val STAGE_NAME_PREFIX = "airbyte_stage_"
internal const val QUOTE: String = "\""
fun sqlEscape(part: String) = part.replace("\\", "\\\\").replace("'", "\\'").replace("\"", "\\\"")
/**
* Surrounds the string instance with double quotation marks (e.g. "some string" -> "\"some
* string\"").
*/
fun String.quote() = "$QUOTE$this$QUOTE"
/**
* Escapes double-quotes in a JSON identifier by doubling them. This is legacy -- I don't know why
* this would be necessary but no harm in keeping it, so I am keeping it.
*
* @return The escaped identifier.
*/
fun escapeJsonIdentifier(identifier: String): String {
// Note that we don't need to escape backslashes here!
// The only special character in an identifier is the double-quote, which needs to be
// doubled.
return identifier.replace(QUOTE, "$QUOTE$QUOTE")
}
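Illustrative usage (not part of this change) of the escaping helpers defined above; the input identifiers are made-up examples.

// Hypothetical, self-contained demo of quoting and quote-doubling.
fun main() {
    val quote = "\""
    fun quoteIdent(s: String) = "$quote$s$quote"
    fun escapeJsonIdentifier(identifier: String) = identifier.replace(quote, "$quote$quote")

    println(quoteIdent("my column"))                  // "my column"
    println(escapeJsonIdentifier("a\"b"))             // a""b  (embedded quotes are doubled)
    println(quoteIdent(escapeJsonIdentifier("a\"b"))) // "a""b"
}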
@@ -1,57 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.sql
import io.airbyte.cdk.load.table.TableName
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import jakarta.inject.Singleton
const val STAGE_NAME_PREFIX = "airbyte_stage_"
internal const val QUOTE: String = "\""
fun sqlEscape(part: String) = part.replace("\\", "\\\\").replace("'", "\\'").replace("\"", "\\\"")
@Singleton
class SnowflakeSqlNameUtils(
private val snowflakeConfiguration: SnowflakeConfiguration,
) {
fun fullyQualifiedName(tableName: TableName): String =
combineParts(listOf(getDatabaseName(), tableName.namespace, tableName.name))
fun fullyQualifiedNamespace(namespace: String) =
combineParts(listOf(getDatabaseName(), namespace))
fun fullyQualifiedStageName(tableName: TableName, escape: Boolean = false): String {
val currentTableName =
if (escape) {
tableName.name
} else {
tableName.name
}
return combineParts(
parts =
listOf(
getDatabaseName(),
tableName.namespace,
"$STAGE_NAME_PREFIX$currentTableName"
),
escape = escape,
)
}
fun combineParts(parts: List<String>, escape: Boolean = false): String =
parts
.map { if (escape) sqlEscape(it) else it }
.joinToString(separator = ".") {
if (!it.startsWith(QUOTE)) {
"$QUOTE$it$QUOTE"
} else {
it
}
}
private fun getDatabaseName() = snowflakeConfiguration.database.toSnowflakeCompatibleName()
}
@@ -6,38 +6,39 @@ package io.airbyte.integrations.destination.snowflake.write
import io.airbyte.cdk.SystemErrorException import io.airbyte.cdk.SystemErrorException
import io.airbyte.cdk.load.command.Dedupe import io.airbyte.cdk.load.command.Dedupe
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationStream import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.orchestration.db.DatabaseInitialStatusGatherer import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator import io.airbyte.cdk.load.table.DatabaseInitialStatusGatherer
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadInitialStatus import io.airbyte.cdk.load.table.TempTableNameGenerator
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendStreamLoader import io.airbyte.cdk.load.table.directload.DirectLoadInitialStatus
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendTruncateStreamLoader import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendStreamLoader
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableDedupStreamLoader import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendTruncateStreamLoader
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableDedupTruncateStreamLoader import io.airbyte.cdk.load.table.directload.DirectLoadTableDedupStreamLoader
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig import io.airbyte.cdk.load.table.directload.DirectLoadTableDedupTruncateStreamLoader
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.write.DestinationWriter import io.airbyte.cdk.load.write.DestinationWriter
import io.airbyte.cdk.load.write.StreamLoader import io.airbyte.cdk.load.write.StreamLoader
import io.airbyte.cdk.load.write.StreamStateStore import io.airbyte.cdk.load.write.StreamStateStore
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
import jakarta.inject.Singleton import jakarta.inject.Singleton
@Singleton @Singleton
class SnowflakeWriter( class SnowflakeWriter(
private val names: TableCatalog, private val catalog: DestinationCatalog,
private val stateGatherer: DatabaseInitialStatusGatherer<DirectLoadInitialStatus>, private val stateGatherer: DatabaseInitialStatusGatherer<DirectLoadInitialStatus>,
private val streamStateStore: StreamStateStore<DirectLoadTableExecutionConfig>, private val streamStateStore: StreamStateStore<DirectLoadTableExecutionConfig>,
private val snowflakeClient: SnowflakeAirbyteClient, private val snowflakeClient: SnowflakeAirbyteClient,
private val tempTableNameGenerator: TempTableNameGenerator,
private val snowflakeConfiguration: SnowflakeConfiguration, private val snowflakeConfiguration: SnowflakeConfiguration,
private val tempTableNameGenerator: TempTableNameGenerator,
) : DestinationWriter { ) : DestinationWriter {
private lateinit var initialStatuses: Map<DestinationStream, DirectLoadInitialStatus> private lateinit var initialStatuses: Map<DestinationStream, DirectLoadInitialStatus>
override suspend fun setup() { override suspend fun setup() {
names.values catalog.streams
.map { (tableNames, _) -> tableNames.finalTableName!!.namespace } .map { it.tableSchema.tableNames.finalTableName!!.namespace }
.toSet() .toSet()
.forEach { snowflakeClient.createNamespace(it) } .forEach { snowflakeClient.createNamespace(it) }
@@ -45,15 +46,15 @@ class SnowflakeWriter(
escapeJsonIdentifier(snowflakeConfiguration.internalTableSchema) escapeJsonIdentifier(snowflakeConfiguration.internalTableSchema)
) )
initialStatuses = stateGatherer.gatherInitialStatus(names) initialStatuses = stateGatherer.gatherInitialStatus()
} }
override fun createStreamLoader(stream: DestinationStream): StreamLoader { override fun createStreamLoader(stream: DestinationStream): StreamLoader {
val initialStatus = initialStatuses[stream]!! val initialStatus = initialStatuses[stream]!!
val tableNameInfo = names[stream]!! val realTableName = stream.tableSchema.tableNames.finalTableName!!
val realTableName = tableNameInfo.tableNames.finalTableName!! val tempTableName = stream.tableSchema.tableNames.tempTableName!!
val tempTableName = tempTableNameGenerator.generate(realTableName) val columnNameMapping =
val columnNameMapping = tableNameInfo.columnNameMapping ColumnNameMapping(stream.tableSchema.columnSchema.inputToFinalColumnNames)
return when (stream.minimumGenerationId) { return when (stream.minimumGenerationId) {
0L -> 0L ->
when (stream.importType) { when (stream.importType) {
@@ -9,11 +9,12 @@ import de.siegmar.fastcsv.writer.CsvWriter
import de.siegmar.fastcsv.writer.LineDelimiter import de.siegmar.fastcsv.writer.LineDelimiter
import de.siegmar.fastcsv.writer.QuoteStrategies import de.siegmar.fastcsv.writer.QuoteStrategies
import io.airbyte.cdk.load.data.AirbyteValue import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.table.TableName import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.QUOTE import io.airbyte.integrations.destination.snowflake.sql.QUOTE
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.github.oshai.kotlinlogging.KotlinLogging import io.github.oshai.kotlinlogging.KotlinLogging
import java.io.File import java.io.File
import java.io.OutputStream import java.io.OutputStream
@@ -36,10 +37,11 @@ private const val CSV_WRITER_BUFFER_SIZE = 1024 * 1024 // 1 MB
class SnowflakeInsertBuffer( class SnowflakeInsertBuffer(
private val tableName: TableName, private val tableName: TableName,
val columns: LinkedHashMap<String, String>,
private val snowflakeClient: SnowflakeAirbyteClient, private val snowflakeClient: SnowflakeAirbyteClient,
val snowflakeConfiguration: SnowflakeConfiguration, val snowflakeConfiguration: SnowflakeConfiguration,
val snowflakeColumnUtils: SnowflakeColumnUtils, val columnSchema: ColumnSchema,
private val columnManager: SnowflakeColumnManager,
private val snowflakeRecordFormatter: SnowflakeRecordFormatter,
private val flushLimit: Int = DEFAULT_FLUSH_LIMIT, private val flushLimit: Int = DEFAULT_FLUSH_LIMIT,
) { ) {
@@ -57,12 +59,6 @@ class SnowflakeInsertBuffer(
.lineDelimiter(CSV_LINE_DELIMITER) .lineDelimiter(CSV_LINE_DELIMITER)
.quoteStrategy(QuoteStrategies.REQUIRED) .quoteStrategy(QuoteStrategies.REQUIRED)
private val snowflakeRecordFormatter: SnowflakeRecordFormatter =
when (snowflakeConfiguration.legacyRawTablesOnly) {
true -> SnowflakeRawRecordFormatter(columns, snowflakeColumnUtils)
else -> SnowflakeSchemaRecordFormatter(columns, snowflakeColumnUtils)
}
fun accumulate(recordFields: Map<String, AirbyteValue>) { fun accumulate(recordFields: Map<String, AirbyteValue>) {
if (csvFilePath == null) { if (csvFilePath == null) {
val csvFile = createCsvFile() val csvFile = createCsvFile()
@@ -92,7 +88,9 @@ class SnowflakeInsertBuffer(
"Copying staging data into ${tableName.toPrettyString(quote = QUOTE)}..." "Copying staging data into ${tableName.toPrettyString(quote = QUOTE)}..."
} }
// Finally, copy the data from the staging table to the final table // Finally, copy the data from the staging table to the final table
snowflakeClient.copyFromStage(tableName, filePath.fileName.toString()) // Pass column names to ensure correct mapping even after ALTER TABLE operations
val columnNames = columnManager.getTableColumnNames(columnSchema)
snowflakeClient.copyFromStage(tableName, filePath.fileName.toString(), columnNames)
logger.info { logger.info {
"Finished insert of $recordCount row(s) into ${tableName.toPrettyString(quote = QUOTE)}." "Finished insert of $recordCount row(s) into ${tableName.toPrettyString(quote = QUOTE)}."
} }
@@ -117,7 +115,9 @@ class SnowflakeInsertBuffer(
private fun writeToCsvFile(record: Map<String, AirbyteValue>) { private fun writeToCsvFile(record: Map<String, AirbyteValue>) {
csvWriter?.let { csvWriter?.let {
it.writeRecord(snowflakeRecordFormatter.format(record).map { col -> col.toString() }) it.writeRecord(
snowflakeRecordFormatter.format(record, columnSchema).map { col -> col.toString() }
)
recordCount++ recordCount++
if ((recordCount % flushLimit) == 0) { if ((recordCount % flushLimit) == 0) {
it.flush() it.flush()
@@ -8,99 +8,62 @@ import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.NullValue import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.data.StringValue import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.data.csv.toCsvValue import io.airbyte.cdk.load.data.csv.toCsvValue
import io.airbyte.cdk.load.message.Meta import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.cdk.load.util.Jsons import io.airbyte.cdk.load.util.Jsons
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
interface SnowflakeRecordFormatter { interface SnowflakeRecordFormatter {
fun format(record: Map<String, AirbyteValue>): List<Any> fun format(record: Map<String, AirbyteValue>, columnSchema: ColumnSchema): List<Any>
} }
class SnowflakeSchemaRecordFormatter( class SnowflakeSchemaRecordFormatter : SnowflakeRecordFormatter {
private val columns: LinkedHashMap<String, String>, override fun format(record: Map<String, AirbyteValue>, columnSchema: ColumnSchema): List<Any> {
val snowflakeColumnUtils: SnowflakeColumnUtils, val result = mutableListOf<Any>()
) : SnowflakeRecordFormatter { val userColumns = columnSchema.finalSchema.keys
private val airbyteColumnNames = // WARNING: MUST match the order defined in SnowflakeColumnManager#getTableColumnNames
snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet() //
// Why don't we just use that here? Well, unlike the user fields, the meta fields on the
// record are not munged for the destination. So we must access the values for those columns
// using the original lowercase meta key.
result.add(record[COLUMN_NAME_AB_RAW_ID].toCsvValue())
result.add(record[COLUMN_NAME_AB_EXTRACTED_AT].toCsvValue())
result.add(record[COLUMN_NAME_AB_META].toCsvValue())
result.add(record[COLUMN_NAME_AB_GENERATION_ID].toCsvValue())
override fun format(record: Map<String, AirbyteValue>): List<Any> = // Add user columns from the final schema
columns.map { (columnName, _) -> userColumns.forEach { columnName -> result.add(record[columnName].toCsvValue()) }
/*
* Meta columns are forced to uppercase for backwards compatibility with previous return result
* versions of the destination. Therefore, convert the column to lowercase so }
* that it can match the constants, which use the lowercase version of the meta
* column names.
*/
if (airbyteColumnNames.contains(columnName)) {
record[columnName.lowercase()].toCsvValue()
} else {
record.keys
// The columns retrieved from Snowflake do not have any escaping applied.
// Therefore, re-apply the compatible name escaping to the name of the
// columns retrieved from Snowflake. The record keys should already have
// been escaped by the CDK before arriving at the aggregate, so no need
// to escape again here.
.find { it == columnName.toSnowflakeCompatibleName() }
?.let { record[it].toCsvValue() }
?: ""
}
}
} }
class SnowflakeRawRecordFormatter( class SnowflakeRawRecordFormatter : SnowflakeRecordFormatter {
columns: LinkedHashMap<String, String>,
val snowflakeColumnUtils: SnowflakeColumnUtils,
) : SnowflakeRecordFormatter {
private val columns = columns.keys
private val airbyteColumnNames = override fun format(record: Map<String, AirbyteValue>, columnSchema: ColumnSchema): List<Any> =
snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet()
override fun format(record: Map<String, AirbyteValue>): List<Any> =
toOutputRecord(record.toMutableMap()) toOutputRecord(record.toMutableMap())
private fun toOutputRecord(record: MutableMap<String, AirbyteValue>): List<Any> { private fun toOutputRecord(record: MutableMap<String, AirbyteValue>): List<Any> {
val outputRecord = mutableListOf<Any>() val outputRecord = mutableListOf<Any>()
// Copy the Airbyte metadata columns to the raw output, removing each val mutableRecord = record.toMutableMap()
// one from the record to avoid duplicates in the "data" field
columns // Add meta columns in order (except _airbyte_data which we handle specially)
.filter { airbyteColumnNames.contains(it) && it != Meta.COLUMN_NAME_DATA } outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_RAW_ID)?.toCsvValue() ?: "")
.forEach { column -> safeAddToOutput(column, record, outputRecord) } outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_EXTRACTED_AT)?.toCsvValue() ?: "")
outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_META)?.toCsvValue() ?: "")
outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_GENERATION_ID)?.toCsvValue() ?: "")
outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_LOADED_AT)?.toCsvValue() ?: "")
// Do not output null values in the JSON raw output // Do not output null values in the JSON raw output
val filteredRecord = record.filter { (_, v) -> v !is NullValue } val filteredRecord = mutableRecord.filter { (_, v) -> v !is NullValue }
// Convert all the remaining columns in the record to a JSON document stored in the "data"
// column. Add it in the same position as the _airbyte_data column in the column list to // Convert all the remaining columns to a JSON document stored in the "data" column
// ensure it is inserted into the proper column in the table. outputRecord.add(StringValue(Jsons.writeValueAsString(filteredRecord)).toCsvValue())
insert(
columns.indexOf(Meta.COLUMN_NAME_DATA),
StringValue(Jsons.writeValueAsString(filteredRecord)).toCsvValue(),
outputRecord
)
return outputRecord return outputRecord
} }
private fun safeAddToOutput(
key: String,
record: MutableMap<String, AirbyteValue>,
output: MutableList<Any>
) {
val extractedValue = record.remove(key)
// Ensure that the data is inserted into the list at the same position as the column
insert(columns.indexOf(key), extractedValue?.toCsvValue() ?: "", output)
}
private fun insert(index: Int, value: Any, list: MutableList<Any>) {
/*
* Attempt to insert the value into the proper order in the list. If the index
* is already present in the list, use the add(index, element) method to insert it
* into the proper order and push everything to the right. If the index is at the
* end of the list, just use add(element) to insert it at the end. If the index
* is further beyond the end of the list, throw an exception as that should not occur.
*/
if (index < list.size) list.add(index, value)
else if (index == list.size || index == list.size + 1) list.add(value)
else throw IndexOutOfBoundsException()
}
} }
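Illustrative sketch (not part of this change) of the fixed column order the schema formatter above writes: meta columns first, read via their original lowercase keys, then the user columns. The record contents are made-up examples.

// Hypothetical sketch of the emitted row order.
fun formatRow(record: Map<String, Any?>, userColumns: List<String>): List<Any?> {
    val metaKeys = listOf("_airbyte_raw_id", "_airbyte_extracted_at", "_airbyte_meta", "_airbyte_generation_id")
    return metaKeys.map { record[it] } + userColumns.map { record[it] }
}

fun main() {
    val record = mapOf("_airbyte_raw_id" to "uuid-1", "_airbyte_extracted_at" to 1700000000L, "ID" to 42)
    println(formatRow(record, listOf("ID", "NAME"))) // [uuid-1, 1700000000, null, null, 42, null]
}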
@@ -1,25 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.write.transform
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.dataflow.transform.ColumnNameMapper
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import jakarta.inject.Singleton
@Singleton
class SnowflakeColumnNameMapper(
private val catalogInfo: TableCatalog,
private val snowflakeConfiguration: SnowflakeConfiguration,
) : ColumnNameMapper {
override fun getMappedColumnName(stream: DestinationStream, columnName: String): String {
if (snowflakeConfiguration.legacyRawTablesOnly == true) {
return columnName
} else {
return catalogInfo.getMappedColumnName(stream, columnName)!!
}
}
}
@@ -7,9 +7,11 @@ package io.airbyte.integrations.destination.snowflake.component
import io.airbyte.cdk.load.component.TableOperationsFixtures import io.airbyte.cdk.load.component.TableOperationsFixtures
import io.airbyte.cdk.load.component.TableOperationsSuite import io.airbyte.cdk.load.component.TableOperationsSuite
import io.airbyte.cdk.load.message.Meta import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.idTestWithCdcMapping import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.idTestWithCdcMapping
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.testMapping import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.testMapping
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeTestTableOperationsClient
import io.micronaut.test.extensions.junit5.annotation.MicronautTest import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.junit.jupiter.api.parallel.Execution import org.junit.jupiter.api.parallel.Execution
@@ -20,6 +22,7 @@ import org.junit.jupiter.api.parallel.ExecutionMode
class SnowflakeTableOperationsTest( class SnowflakeTableOperationsTest(
override val client: SnowflakeAirbyteClient, override val client: SnowflakeAirbyteClient,
override val testClient: SnowflakeTestTableOperationsClient, override val testClient: SnowflakeTestTableOperationsClient,
override val schemaFactory: TableSchemaFactory,
) : TableOperationsSuite { ) : TableOperationsSuite {
override val airbyteMetaColumnMapping = Meta.COLUMN_NAMES.associateWith { it.uppercase() } override val airbyteMetaColumnMapping = Meta.COLUMN_NAMES.associateWith { it.uppercase() }
@@ -9,13 +9,15 @@ import io.airbyte.cdk.load.component.TableSchemaEvolutionFixtures
import io.airbyte.cdk.load.component.TableSchemaEvolutionSuite import io.airbyte.cdk.load.component.TableSchemaEvolutionSuite
import io.airbyte.cdk.load.data.StringValue import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.message.Meta import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.cdk.load.table.ColumnNameMapping import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.util.serializeToString import io.airbyte.cdk.load.util.serializeToString
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.allTypesColumnNameMapping import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.allTypesColumnNameMapping
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.allTypesTableSchema import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.allTypesTableSchema
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.idAndTestMapping import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.idAndTestMapping
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.testMapping import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.testMapping
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeTestTableOperationsClient
import io.micronaut.test.extensions.junit5.annotation.MicronautTest import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.junit.jupiter.api.parallel.Execution import org.junit.jupiter.api.parallel.Execution
@@ -27,6 +29,7 @@ class SnowflakeTableSchemaEvolutionTest(
override val client: SnowflakeAirbyteClient, override val client: SnowflakeAirbyteClient,
override val opsClient: SnowflakeAirbyteClient, override val opsClient: SnowflakeAirbyteClient,
override val testClient: SnowflakeTestTableOperationsClient, override val testClient: SnowflakeTestTableOperationsClient,
override val schemaFactory: TableSchemaFactory,
) : TableSchemaEvolutionSuite { ) : TableSchemaEvolutionSuite {
override val airbyteMetaColumnMapping = Meta.COLUMN_NAMES.associateWith { it.uppercase() } override val airbyteMetaColumnMapping = Meta.COLUMN_NAMES.associateWith { it.uppercase() }
@@ -2,7 +2,7 @@
* Copyright (c) 2025 Airbyte, Inc., all rights reserved. * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/ */
package io.airbyte.integrations.destination.snowflake.component package io.airbyte.integrations.destination.snowflake.component.config
import io.airbyte.cdk.load.component.ColumnType import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.component.TableOperationsFixtures import io.airbyte.cdk.load.component.TableOperationsFixtures
@@ -33,6 +33,8 @@ object SnowflakeComponentTestFixtures {
"TIME_NTZ" to ColumnType("TIME", true), "TIME_NTZ" to ColumnType("TIME", true),
"ARRAY" to ColumnType("ARRAY", true), "ARRAY" to ColumnType("ARRAY", true),
"OBJECT" to ColumnType("OBJECT", true), "OBJECT" to ColumnType("OBJECT", true),
"UNION" to ColumnType("VARIANT", true),
"LEGACY_UNION" to ColumnType("VARIANT", true),
"UNKNOWN" to ColumnType("VARIANT", true), "UNKNOWN" to ColumnType("VARIANT", true),
) )
) )
@@ -2,7 +2,7 @@
* Copyright (c) 2025 Airbyte, Inc., all rights reserved. * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/ */
package io.airbyte.integrations.destination.snowflake.component package io.airbyte.integrations.destination.snowflake.component.config
import io.airbyte.cdk.load.component.config.TestConfigLoader.loadTestConfig import io.airbyte.cdk.load.component.config.TestConfigLoader.loadTestConfig
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
@@ -2,22 +2,24 @@
* Copyright (c) 2025 Airbyte, Inc., all rights reserved. * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/ */
package io.airbyte.integrations.destination.snowflake.component package io.airbyte.integrations.destination.snowflake.component.config
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.component.TestTableOperationsClient import io.airbyte.cdk.load.component.TestTableOperationsClient
import io.airbyte.cdk.load.data.AirbyteValue import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.dataflow.state.PartitionKey import io.airbyte.cdk.load.dataflow.state.PartitionKey
import io.airbyte.cdk.load.dataflow.transform.RecordDTO import io.airbyte.cdk.load.dataflow.transform.RecordDTO
import io.airbyte.cdk.load.table.TableName import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.util.Jsons import io.airbyte.cdk.load.util.Jsons
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.client.execute import io.airbyte.integrations.destination.snowflake.client.execute
import io.airbyte.integrations.destination.snowflake.dataflow.SnowflakeAggregate import io.airbyte.integrations.destination.snowflake.dataflow.SnowflakeAggregate
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils
import io.airbyte.integrations.destination.snowflake.sql.andLog import io.airbyte.integrations.destination.snowflake.sql.andLog
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
import io.micronaut.context.annotation.Requires import io.micronaut.context.annotation.Requires
import jakarta.inject.Singleton import jakarta.inject.Singleton
import java.time.format.DateTimeFormatter import java.time.format.DateTimeFormatter
@@ -29,25 +31,40 @@ import net.snowflake.client.jdbc.SnowflakeTimestampWithTimezone
class SnowflakeTestTableOperationsClient( class SnowflakeTestTableOperationsClient(
private val client: SnowflakeAirbyteClient, private val client: SnowflakeAirbyteClient,
private val dataSource: DataSource, private val dataSource: DataSource,
private val snowflakeSqlNameUtils: SnowflakeSqlNameUtils, private val sqlGenerator: SnowflakeDirectLoadSqlGenerator,
private val snowflakeColumnUtils: SnowflakeColumnUtils,
private val snowflakeConfiguration: SnowflakeConfiguration, private val snowflakeConfiguration: SnowflakeConfiguration,
private val columnManager: SnowflakeColumnManager,
private val snowflakeRecordFormatter: SnowflakeRecordFormatter,
) : TestTableOperationsClient { ) : TestTableOperationsClient {
override suspend fun dropNamespace(namespace: String) { override suspend fun dropNamespace(namespace: String) {
dataSource.execute( dataSource.execute(
"DROP SCHEMA IF EXISTS ${snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)}".andLog() "DROP SCHEMA IF EXISTS ${sqlGenerator.fullyQualifiedNamespace(namespace)}".andLog()
) )
} }
override suspend fun insertRecords(table: TableName, records: List<Map<String, AirbyteValue>>) { override suspend fun insertRecords(table: TableName, records: List<Map<String, AirbyteValue>>) {
// TODO: we should just pass a proper column schema
// Since we don't pass in a proper column schema, we have to recreate one here
// Fetch the columns and filter out the meta columns so we're just looking at user columns
val columnTypes =
client.describeTable(table).filterNot {
columnManager.getMetaColumnNames().contains(it.key)
}
val columnSchema =
io.airbyte.cdk.load.schema.model.ColumnSchema(
inputToFinalColumnNames = columnTypes.keys.associateWith { it },
finalSchema = columnTypes.mapValues { (_, _) -> ColumnType("", true) },
inputSchema = emptyMap() // Not needed for insert buffer
)
val a = val a =
SnowflakeAggregate( SnowflakeAggregate(
SnowflakeInsertBuffer( SnowflakeInsertBuffer(
table, tableName = table,
client.describeTable(table), snowflakeClient = client,
client, snowflakeConfiguration = snowflakeConfiguration,
snowflakeConfiguration, columnSchema = columnSchema,
snowflakeColumnUtils, columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
) )
) )
records.forEach { a.accept(RecordDTO(it, PartitionKey(""), 0, 0)) } records.forEach { a.accept(RecordDTO(it, PartitionKey(""), 0, 0)) }
@@ -7,11 +7,11 @@ package io.airbyte.integrations.destination.snowflake.write
import io.airbyte.cdk.load.test.util.DestinationCleaner import io.airbyte.cdk.load.test.util.DestinationCleaner
import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory
import io.airbyte.integrations.destination.snowflake.cdk.SnowflakeMigratingConfigurationSpecificationSupplier import io.airbyte.integrations.destination.snowflake.cdk.SnowflakeMigratingConfigurationSpecificationSupplier
import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfigurationFactory import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfigurationFactory
import io.airbyte.integrations.destination.snowflake.sql.STAGE_NAME_PREFIX import io.airbyte.integrations.destination.snowflake.sql.STAGE_NAME_PREFIX
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
import io.airbyte.integrations.destination.snowflake.sql.quote import io.airbyte.integrations.destination.snowflake.sql.quote
import java.nio.file.Files import java.nio.file.Files
import java.sql.Connection import java.sql.Connection
@@ -12,16 +12,22 @@ import io.airbyte.cdk.load.data.ObjectValue
import io.airbyte.cdk.load.data.json.toAirbyteValue import io.airbyte.cdk.load.data.json.toAirbyteValue
import io.airbyte.cdk.load.message.Meta import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN
import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator
import io.airbyte.cdk.load.test.util.DestinationDataDumper import io.airbyte.cdk.load.test.util.DestinationDataDumper
import io.airbyte.cdk.load.test.util.OutputRecord import io.airbyte.cdk.load.test.util.OutputRecord
import io.airbyte.cdk.load.util.UUIDGenerator
import io.airbyte.cdk.load.util.deserializeToNode import io.airbyte.cdk.load.util.deserializeToNode
import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory
import io.airbyte.integrations.destination.snowflake.db.SnowflakeFinalTableNameGenerator import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.schema.SnowflakeTableSchemaMapper
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
import io.airbyte.integrations.destination.snowflake.sql.sqlEscape import io.airbyte.integrations.destination.snowflake.sql.sqlEscape
import java.math.BigDecimal import java.math.BigDecimal
import java.sql.Date
import java.sql.Time
import java.sql.Timestamp
import net.snowflake.client.jdbc.SnowflakeTimestampWithTimezone import net.snowflake.client.jdbc.SnowflakeTimestampWithTimezone
private val AIRBYTE_META_COLUMNS = Meta.COLUMN_NAMES + setOf(CDC_DELETED_AT_COLUMN) private val AIRBYTE_META_COLUMNS = Meta.COLUMN_NAMES + setOf(CDC_DELETED_AT_COLUMN)
@@ -34,8 +40,14 @@ class SnowflakeDataDumper(
stream: DestinationStream stream: DestinationStream
): List<OutputRecord> { ): List<OutputRecord> {
val config = configProvider(spec) val config = configProvider(spec)
val sqlUtils = SnowflakeSqlNameUtils(config) val snowflakeFinalTableNameGenerator =
val snowflakeFinalTableNameGenerator = SnowflakeFinalTableNameGenerator(config) SnowflakeTableSchemaMapper(
config = config,
tempTableNameGenerator = DefaultTempTableNameGenerator(),
)
val snowflakeColumnManager = SnowflakeColumnManager(config)
val sqlGenerator =
SnowflakeDirectLoadSqlGenerator(UUIDGenerator(), config, snowflakeColumnManager)
val dataSource = val dataSource =
SnowflakeBeanFactory() SnowflakeBeanFactory()
.snowflakeDataSource(snowflakeConfiguration = config, airbyteEdition = "COMMUNITY") .snowflakeDataSource(snowflakeConfiguration = config, airbyteEdition = "COMMUNITY")
@@ -46,7 +58,7 @@ class SnowflakeDataDumper(
ds.connection.use { connection -> ds.connection.use { connection ->
val statement = connection.createStatement() val statement = connection.createStatement()
val tableName = val tableName =
snowflakeFinalTableNameGenerator.getTableName(stream.mappedDescriptor) snowflakeFinalTableNameGenerator.toFinalTableName(stream.mappedDescriptor)
// First check if the table exists // First check if the table exists
val tableExistsQuery = val tableExistsQuery =
@@ -69,7 +81,7 @@ class SnowflakeDataDumper(
val resultSet = val resultSet =
statement.executeQuery( statement.executeQuery(
"SELECT * FROM ${sqlUtils.fullyQualifiedName(tableName)}" "SELECT * FROM ${sqlGenerator.fullyQualifiedName(tableName)}"
) )
while (resultSet.next()) { while (resultSet.next()) {
@@ -143,10 +155,10 @@ class SnowflakeDataDumper(
private fun convertValue(value: Any?): Any? = private fun convertValue(value: Any?): Any? =
when (value) { when (value) {
is BigDecimal -> value.toBigInteger() is BigDecimal -> value.toBigInteger()
is java.sql.Date -> value.toLocalDate() is Date -> value.toLocalDate()
is SnowflakeTimestampWithTimezone -> value.toZonedDateTime() is SnowflakeTimestampWithTimezone -> value.toZonedDateTime()
is java.sql.Time -> value.toLocalTime() is Time -> value.toLocalTime()
is java.sql.Timestamp -> value.toLocalDateTime() is Timestamp -> value.toLocalDateTime()
else -> value else -> value
} }
} }

@@ -15,7 +15,7 @@ import io.airbyte.cdk.load.dataflow.transform.ValidationResult
import io.airbyte.cdk.load.message.Meta import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.test.util.ExpectedRecordMapper import io.airbyte.cdk.load.test.util.ExpectedRecordMapper
import io.airbyte.cdk.load.test.util.OutputRecord import io.airbyte.cdk.load.test.util.OutputRecord
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.write.transform.SnowflakeValueCoercer import io.airbyte.integrations.destination.snowflake.write.transform.SnowflakeValueCoercer
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Change import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Change

@@ -5,7 +5,7 @@
package io.airbyte.integrations.destination.snowflake.write package io.airbyte.integrations.destination.snowflake.write
import io.airbyte.cdk.load.test.util.NameMapper import io.airbyte.cdk.load.test.util.NameMapper
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
class SnowflakeNameMapper : NameMapper { class SnowflakeNameMapper : NameMapper {
override fun mapFieldName(path: List<String>): List<String> = override fun mapFieldName(path: List<String>): List<String> =

@@ -9,13 +9,16 @@ import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.data.StringValue import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.data.json.toAirbyteValue import io.airbyte.cdk.load.data.json.toAirbyteValue
import io.airbyte.cdk.load.message.Meta import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator
import io.airbyte.cdk.load.test.util.DestinationDataDumper import io.airbyte.cdk.load.test.util.DestinationDataDumper
import io.airbyte.cdk.load.test.util.OutputRecord import io.airbyte.cdk.load.test.util.OutputRecord
import io.airbyte.cdk.load.util.UUIDGenerator
import io.airbyte.cdk.load.util.deserializeToNode import io.airbyte.cdk.load.util.deserializeToNode
import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory
import io.airbyte.integrations.destination.snowflake.db.SnowflakeFinalTableNameGenerator import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeTableSchemaMapper
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
class SnowflakeRawDataDumper( class SnowflakeRawDataDumper(
private val configProvider: (ConfigurationSpecification) -> SnowflakeConfiguration private val configProvider: (ConfigurationSpecification) -> SnowflakeConfiguration
@@ -27,8 +30,18 @@ class SnowflakeRawDataDumper(
val output = mutableListOf<OutputRecord>() val output = mutableListOf<OutputRecord>()
val config = configProvider(spec) val config = configProvider(spec)
val sqlUtils = SnowflakeSqlNameUtils(config) val snowflakeColumnManager = SnowflakeColumnManager(config)
val snowflakeFinalTableNameGenerator = SnowflakeFinalTableNameGenerator(config) val sqlGenerator =
SnowflakeDirectLoadSqlGenerator(
UUIDGenerator(),
config,
snowflakeColumnManager,
)
val snowflakeFinalTableNameGenerator =
SnowflakeTableSchemaMapper(
config = config,
tempTableNameGenerator = DefaultTempTableNameGenerator(),
)
val dataSource = val dataSource =
SnowflakeBeanFactory() SnowflakeBeanFactory()
.snowflakeDataSource(snowflakeConfiguration = config, airbyteEdition = "COMMUNITY") .snowflakeDataSource(snowflakeConfiguration = config, airbyteEdition = "COMMUNITY")
@@ -37,11 +50,11 @@ class SnowflakeRawDataDumper(
ds.connection.use { connection -> ds.connection.use { connection ->
val statement = connection.createStatement() val statement = connection.createStatement()
val tableName = val tableName =
snowflakeFinalTableNameGenerator.getTableName(stream.mappedDescriptor) snowflakeFinalTableNameGenerator.toFinalTableName(stream.mappedDescriptor)
val resultSet = val resultSet =
statement.executeQuery( statement.executeQuery(
"SELECT * FROM ${sqlUtils.fullyQualifiedName(tableName)}" "SELECT * FROM ${sqlGenerator.fullyQualifiedName(tableName)}"
) )
while (resultSet.next()) { while (resultSet.next()) {

@@ -6,7 +6,7 @@ package io.airbyte.integrations.destination.snowflake
import com.zaxxer.hikari.HikariConfig import com.zaxxer.hikari.HikariConfig
import com.zaxxer.hikari.HikariDataSource import com.zaxxer.hikari.HikariDataSource
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.CdcDeletionMode import io.airbyte.integrations.destination.snowflake.spec.CdcDeletionMode
import io.airbyte.integrations.destination.snowflake.spec.KeyPairAuthConfiguration import io.airbyte.integrations.destination.snowflake.spec.KeyPairAuthConfiguration
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration

@@ -5,11 +5,9 @@
package io.airbyte.integrations.destination.snowflake.check package io.airbyte.integrations.destination.snowflake.check
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.DEFAULT_COLUMNS
import io.airbyte.integrations.destination.snowflake.sql.RAW_DATA_COLUMN
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.mockk.coEvery import io.mockk.coEvery
import io.mockk.coVerify import io.mockk.coVerify
import io.mockk.every import io.mockk.every
@@ -23,47 +21,31 @@ internal class SnowflakeCheckerTest {
@ParameterizedTest @ParameterizedTest
@ValueSource(booleans = [true, false]) @ValueSource(booleans = [true, false])
fun testSuccessfulCheck(isLegacyRawTablesOnly: Boolean) { fun testSuccessfulCheck(isLegacyRawTablesOnly: Boolean) {
val defaultColumnsMap =
if (isLegacyRawTablesOnly) {
linkedMapOf<String, String>().also { map ->
(DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
map[it.columnName] = it.columnType
}
}
} else {
linkedMapOf<String, String>().also { map ->
(DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
map[it.columnName.toSnowflakeCompatibleName()] = it.columnType
}
}
}
val defaultColumns = defaultColumnsMap.keys.toMutableList()
val snowflakeAirbyteClient: SnowflakeAirbyteClient = val snowflakeAirbyteClient: SnowflakeAirbyteClient =
mockk(relaxed = true) { mockk(relaxed = true) { coEvery { countTable(any()) } returns 1L }
coEvery { countTable(any()) } returns 1L
coEvery { describeTable(any()) } returns defaultColumnsMap
}
val testSchema = "test-schema" val testSchema = "test-schema"
val snowflakeConfiguration: SnowflakeConfiguration = mockk { val snowflakeConfiguration: SnowflakeConfiguration = mockk {
every { schema } returns testSchema every { schema } returns testSchema
every { legacyRawTablesOnly } returns isLegacyRawTablesOnly every { legacyRawTablesOnly } returns isLegacyRawTablesOnly
} }
val snowflakeColumnUtils =
mockk<SnowflakeColumnUtils>(relaxUnitFun = true) { val columnManager: SnowflakeColumnManager = SnowflakeColumnManager(snowflakeConfiguration)
every { getFormattedDefaultColumnNames(any()) } returns defaultColumns
}
val checker = val checker =
SnowflakeChecker( SnowflakeChecker(
snowflakeAirbyteClient = snowflakeAirbyteClient, snowflakeAirbyteClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration, snowflakeConfiguration = snowflakeConfiguration,
snowflakeColumnUtils = snowflakeColumnUtils, columnManager = columnManager,
) )
checker.check() checker.check()
coVerify(exactly = 1) { coVerify(exactly = 1) {
snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName()) if (isLegacyRawTablesOnly) {
snowflakeAirbyteClient.createNamespace(testSchema)
} else {
snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
}
} }
coVerify(exactly = 1) { snowflakeAirbyteClient.createTable(any(), any(), any(), any()) } coVerify(exactly = 1) { snowflakeAirbyteClient.createTable(any(), any(), any(), any()) }
coVerify(exactly = 1) { snowflakeAirbyteClient.dropTable(any()) } coVerify(exactly = 1) { snowflakeAirbyteClient.dropTable(any()) }
@@ -72,48 +54,32 @@ internal class SnowflakeCheckerTest {
@ParameterizedTest @ParameterizedTest
@ValueSource(booleans = [true, false]) @ValueSource(booleans = [true, false])
fun testUnsuccessfulCheck(isLegacyRawTablesOnly: Boolean) { fun testUnsuccessfulCheck(isLegacyRawTablesOnly: Boolean) {
val defaultColumnsMap =
if (isLegacyRawTablesOnly) {
linkedMapOf<String, String>().also { map ->
(DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
map[it.columnName] = it.columnType
}
}
} else {
linkedMapOf<String, String>().also { map ->
(DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
map[it.columnName.toSnowflakeCompatibleName()] = it.columnType
}
}
}
val defaultColumns = defaultColumnsMap.keys.toMutableList()
val snowflakeAirbyteClient: SnowflakeAirbyteClient = val snowflakeAirbyteClient: SnowflakeAirbyteClient =
mockk(relaxed = true) { mockk(relaxed = true) { coEvery { countTable(any()) } returns 0L }
coEvery { countTable(any()) } returns 0L
coEvery { describeTable(any()) } returns defaultColumnsMap
}
val testSchema = "test-schema" val testSchema = "test-schema"
val snowflakeConfiguration: SnowflakeConfiguration = mockk { val snowflakeConfiguration: SnowflakeConfiguration = mockk {
every { schema } returns testSchema every { schema } returns testSchema
every { legacyRawTablesOnly } returns isLegacyRawTablesOnly every { legacyRawTablesOnly } returns isLegacyRawTablesOnly
} }
val snowflakeColumnUtils =
mockk<SnowflakeColumnUtils>(relaxUnitFun = true) { val columnManager: SnowflakeColumnManager = SnowflakeColumnManager(snowflakeConfiguration)
every { getFormattedDefaultColumnNames(any()) } returns defaultColumns
}
val checker = val checker =
SnowflakeChecker( SnowflakeChecker(
snowflakeAirbyteClient = snowflakeAirbyteClient, snowflakeAirbyteClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration, snowflakeConfiguration = snowflakeConfiguration,
snowflakeColumnUtils = snowflakeColumnUtils, columnManager = columnManager,
) )
assertThrows<IllegalArgumentException> { checker.check() } assertThrows<IllegalArgumentException> { checker.check() }
coVerify(exactly = 1) { coVerify(exactly = 1) {
snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName()) if (isLegacyRawTablesOnly) {
snowflakeAirbyteClient.createNamespace(testSchema)
} else {
snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
}
} }
coVerify(exactly = 1) { snowflakeAirbyteClient.createTable(any(), any(), any(), any()) } coVerify(exactly = 1) { snowflakeAirbyteClient.createTable(any(), any(), any(), any()) }
coVerify(exactly = 1) { snowflakeAirbyteClient.dropTable(any()) } coVerify(exactly = 1) { snowflakeAirbyteClient.dropTable(any()) }

@@ -6,24 +6,16 @@ package io.airbyte.integrations.destination.snowflake.client
import io.airbyte.cdk.ConfigErrorException import io.airbyte.cdk.ConfigErrorException
import io.airbyte.cdk.load.command.DestinationStream import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.command.NamespaceMapper
import io.airbyte.cdk.load.command.Overwrite
import io.airbyte.cdk.load.component.ColumnType import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.config.NamespaceDefinitionType
import io.airbyte.cdk.load.data.AirbyteType
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.ColumnNameMapping import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.TableName import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.db.ColumnDefinition import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.COUNT_TOTAL_ALIAS import io.airbyte.integrations.destination.snowflake.sql.COUNT_TOTAL_ALIAS
import io.airbyte.integrations.destination.snowflake.sql.ColumnAndType
import io.airbyte.integrations.destination.snowflake.sql.DEFAULT_COLUMNS
import io.airbyte.integrations.destination.snowflake.sql.QUOTE import io.airbyte.integrations.destination.snowflake.sql.QUOTE
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
import io.mockk.Runs import io.mockk.Runs
import io.mockk.every import io.mockk.every
@@ -49,31 +41,18 @@ internal class SnowflakeAirbyteClientTest {
private lateinit var client: SnowflakeAirbyteClient private lateinit var client: SnowflakeAirbyteClient
private lateinit var dataSource: DataSource private lateinit var dataSource: DataSource
private lateinit var sqlGenerator: SnowflakeDirectLoadSqlGenerator private lateinit var sqlGenerator: SnowflakeDirectLoadSqlGenerator
private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
private lateinit var snowflakeConfiguration: SnowflakeConfiguration private lateinit var snowflakeConfiguration: SnowflakeConfiguration
private lateinit var columnManager: SnowflakeColumnManager
@BeforeEach @BeforeEach
fun setup() { fun setup() {
dataSource = mockk() dataSource = mockk()
sqlGenerator = mockk(relaxed = true) sqlGenerator = mockk(relaxed = true)
snowflakeColumnUtils =
mockk(relaxed = true) {
every { formatColumnName(any()) } answers
{
firstArg<String>().toSnowflakeCompatibleName()
}
every { getFormattedDefaultColumnNames(any()) } returns
DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() }
}
snowflakeConfiguration = snowflakeConfiguration =
mockk(relaxed = true) { every { database } returns "test_database" } mockk(relaxed = true) { every { database } returns "test_database" }
columnManager = mockk(relaxed = true)
client = client =
SnowflakeAirbyteClient( SnowflakeAirbyteClient(dataSource, sqlGenerator, snowflakeConfiguration, columnManager)
dataSource,
sqlGenerator,
snowflakeColumnUtils,
snowflakeConfiguration
)
} }
@Test @Test
@@ -231,7 +210,7 @@ internal class SnowflakeAirbyteClientTest {
@Test @Test
fun testCreateTable() { fun testCreateTable() {
val columnNameMapping = mockk<ColumnNameMapping>(relaxed = true) val columnNameMapping = mockk<ColumnNameMapping>(relaxed = true)
val stream = mockk<DestinationStream>() val stream = mockk<DestinationStream>(relaxed = true)
val tableName = TableName(namespace = "namespace", name = "name") val tableName = TableName(namespace = "namespace", name = "name")
val resultSet = mockk<ResultSet>(relaxed = true) val resultSet = mockk<ResultSet>(relaxed = true)
val statement = val statement =
@@ -254,9 +233,7 @@ internal class SnowflakeAirbyteClientTest {
columnNameMapping = columnNameMapping, columnNameMapping = columnNameMapping,
replace = true, replace = true,
) )
verify(exactly = 1) { verify(exactly = 1) { sqlGenerator.createTable(tableName, any(), true) }
sqlGenerator.createTable(stream, tableName, columnNameMapping, true)
}
verify(exactly = 1) { sqlGenerator.createSnowflakeStage(tableName) } verify(exactly = 1) { sqlGenerator.createSnowflakeStage(tableName) }
verify(exactly = 2) { mockConnection.close() } verify(exactly = 2) { mockConnection.close() }
} }
@@ -288,7 +265,7 @@ internal class SnowflakeAirbyteClientTest {
targetTableName = destinationTableName, targetTableName = destinationTableName,
) )
verify(exactly = 1) { verify(exactly = 1) {
sqlGenerator.copyTable(columnNameMapping, sourceTableName, destinationTableName) sqlGenerator.copyTable(any<Set<String>>(), sourceTableName, destinationTableName)
} }
verify(exactly = 1) { mockConnection.close() } verify(exactly = 1) { mockConnection.close() }
} }
@@ -299,7 +276,7 @@ internal class SnowflakeAirbyteClientTest {
val columnNameMapping = mockk<ColumnNameMapping>(relaxed = true) val columnNameMapping = mockk<ColumnNameMapping>(relaxed = true)
val sourceTableName = TableName(namespace = "namespace", name = "source") val sourceTableName = TableName(namespace = "namespace", name = "source")
val destinationTableName = TableName(namespace = "namespace", name = "destination") val destinationTableName = TableName(namespace = "namespace", name = "destination")
val stream = mockk<DestinationStream>() val stream = mockk<DestinationStream>(relaxed = true)
val resultSet = mockk<ResultSet>(relaxed = true) val resultSet = mockk<ResultSet>(relaxed = true)
val statement = val statement =
mockk<Statement> { mockk<Statement> {
@@ -322,12 +299,7 @@ internal class SnowflakeAirbyteClientTest {
targetTableName = destinationTableName, targetTableName = destinationTableName,
) )
verify(exactly = 1) { verify(exactly = 1) {
sqlGenerator.upsertTable( sqlGenerator.upsertTable(any(), sourceTableName, destinationTableName)
stream,
columnNameMapping,
sourceTableName,
destinationTableName
)
} }
verify(exactly = 1) { mockConnection.close() } verify(exactly = 1) { mockConnection.close() }
} }
@@ -379,7 +351,7 @@ internal class SnowflakeAirbyteClientTest {
} }
every { dataSource.connection } returns mockConnection every { dataSource.connection } returns mockConnection
every { snowflakeColumnUtils.getGenerationIdColumnName() } returns generationIdColumnName every { columnManager.getGenerationIdColumnName() } returns generationIdColumnName
every { sqlGenerator.getGenerationId(tableName) } returns every { sqlGenerator.getGenerationId(tableName) } returns
"SELECT $generationIdColumnName FROM ${tableName.toPrettyString(QUOTE)}" "SELECT $generationIdColumnName FROM ${tableName.toPrettyString(QUOTE)}"
@@ -501,8 +473,8 @@ internal class SnowflakeAirbyteClientTest {
every { dataSource.connection } returns mockConnection every { dataSource.connection } returns mockConnection
runBlocking { runBlocking {
client.copyFromStage(tableName, "test.csv.gz") client.copyFromStage(tableName, "test.csv.gz", listOf())
verify(exactly = 1) { sqlGenerator.copyFromStage(tableName, "test.csv.gz") } verify(exactly = 1) { sqlGenerator.copyFromStage(tableName, "test.csv.gz", listOf()) }
verify(exactly = 1) { mockConnection.close() } verify(exactly = 1) { mockConnection.close() }
} }
} }
@@ -556,7 +528,7 @@ internal class SnowflakeAirbyteClientTest {
"COL1" andThen "COL1" andThen
COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName() andThen COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName() andThen
"COL2" "COL2"
every { resultSet.getString("type") } returns "VARCHAR(255)" andThen "NUMBER(38,0)" every { resultSet.getString("type") } returns "VARCHAR(255)" andThen "NUMBER"
every { resultSet.getString("null?") } returns "Y" andThen "N" andThen "N" every { resultSet.getString("null?") } returns "Y" andThen "N" andThen "N"
val statement = val statement =
@@ -571,6 +543,10 @@ internal class SnowflakeAirbyteClientTest {
every { dataSource.connection } returns connection every { dataSource.connection } returns connection
// Mock the columnManager to return the correct set of meta columns
every { columnManager.getMetaColumnNames() } returns
setOf(COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName())
val result = client.getColumnsFromDb(tableName) val result = client.getColumnsFromDb(tableName)
val expectedColumns = val expectedColumns =
@@ -582,81 +558,6 @@ internal class SnowflakeAirbyteClientTest {
assertEquals(expectedColumns, result) assertEquals(expectedColumns, result)
} }
@Test
fun `getColumnsFromStream should return correct column definitions`() {
val schema = mockk<AirbyteType>()
val stream =
DestinationStream(
unmappedNamespace = "test_namespace",
unmappedName = "test_stream",
importType = Overwrite,
schema = schema,
generationId = 1,
minimumGenerationId = 1,
syncId = 1,
namespaceMapper = NamespaceMapper(NamespaceDefinitionType.DESTINATION)
)
val columnNameMapping =
ColumnNameMapping(
mapOf(
"col1" to "COL1_MAPPED",
"col2" to "COL2_MAPPED",
)
)
val col1FieldType = mockk<FieldType>()
every { col1FieldType.type } returns mockk()
val col2FieldType = mockk<FieldType>()
every { col2FieldType.type } returns mockk()
every { schema.asColumns() } returns
linkedMapOf("col1" to col1FieldType, "col2" to col2FieldType)
every { snowflakeColumnUtils.toDialectType(col1FieldType.type) } returns "VARCHAR(255)"
every { snowflakeColumnUtils.toDialectType(col2FieldType.type) } returns "NUMBER(38,0)"
every { snowflakeColumnUtils.columnsAndTypes(any(), any()) } returns
listOf(ColumnAndType("COL1_MAPPED", "VARCHAR"), ColumnAndType("COL2_MAPPED", "NUMBER"))
every { snowflakeColumnUtils.formatColumnName(any(), false) } answers
{
firstArg<String>().toSnowflakeCompatibleName()
}
val result = client.getColumnsFromStream(stream, columnNameMapping)
val expectedColumns =
mapOf(
"COL1_MAPPED" to ColumnType("VARCHAR", true),
"COL2_MAPPED" to ColumnType("NUMBER", true),
)
assertEquals(expectedColumns, result)
}
@Test
fun `generateSchemaChanges should correctly identify changes`() {
val columnsInDb =
setOf(
ColumnDefinition("COL1", "VARCHAR"),
ColumnDefinition("COL2", "NUMBER"),
ColumnDefinition("COL3", "BOOLEAN")
)
val columnsInStream =
setOf(
ColumnDefinition("COL1", "VARCHAR"), // Unchanged
ColumnDefinition("COL3", "TEXT"), // Modified
ColumnDefinition("COL4", "DATE") // Added
)
val (added, deleted, modified) = client.generateSchemaChanges(columnsInDb, columnsInStream)
assertEquals(1, added.size)
assertEquals("COL4", added.first().name)
assertEquals(1, deleted.size)
assertEquals("COL2", deleted.first().name)
assertEquals(1, modified.size)
assertEquals("COL3", modified.first().name)
}
@Test @Test
fun testCreateNamespaceWithNetworkFailure() { fun testCreateNamespaceWithNetworkFailure() {
val namespace = "test_namespace" val namespace = "test_namespace"

@@ -4,14 +4,19 @@
package io.airbyte.integrations.destination.snowflake.dataflow package io.airbyte.integrations.destination.snowflake.dataflow
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.command.DestinationStream.Descriptor import io.airbyte.cdk.load.command.DestinationStream.Descriptor
import io.airbyte.cdk.load.dataflow.aggregate.StoreKey import io.airbyte.cdk.load.dataflow.aggregate.StoreKey
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.TableName import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.write.StreamStateStore import io.airbyte.cdk.load.write.StreamStateStore
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRawRecordFormatter
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter
import io.mockk.every import io.mockk.every
import io.mockk.mockk import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertEquals
@@ -22,27 +27,33 @@ internal class SnowflakeAggregateFactoryTest {
@Test @Test
fun testCreatingAggregateWithRawBuffer() { fun testCreatingAggregateWithRawBuffer() {
val descriptor = Descriptor(namespace = "namespace", name = "name") val descriptor = Descriptor(namespace = "namespace", name = "name")
val directLoadTableExecutionConfig = val tableName =
DirectLoadTableExecutionConfig( TableName(
tableName = namespace = descriptor.namespace!!,
TableName( name = descriptor.name,
namespace = descriptor.namespace!!,
name = descriptor.name,
)
) )
val directLoadTableExecutionConfig = DirectLoadTableExecutionConfig(tableName = tableName)
val key = StoreKey(namespace = descriptor.namespace!!, name = descriptor.name) val key = StoreKey(namespace = descriptor.namespace!!, name = descriptor.name)
val streamStore = StreamStateStore<DirectLoadTableExecutionConfig>() val streamStore = StreamStateStore<DirectLoadTableExecutionConfig>()
streamStore.put(descriptor, directLoadTableExecutionConfig) streamStore.put(key, directLoadTableExecutionConfig)
val stream = mockk<DestinationStream>(relaxed = true)
val catalog = mockk<DestinationCatalog> { every { getStream(key) } returns stream }
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true) val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val snowflakeConfiguration = val snowflakeConfiguration =
mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns true } mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns true }
val snowflakeColumnUtils = mockk<SnowflakeColumnUtils>(relaxed = true) val columnManager = SnowflakeColumnManager(snowflakeConfiguration)
val snowflakeRecordFormatter: SnowflakeRecordFormatter = SnowflakeRawRecordFormatter()
val factory = val factory =
SnowflakeAggregateFactory( SnowflakeAggregateFactory(
snowflakeClient = snowflakeClient, snowflakeClient = snowflakeClient,
streamStateStore = streamStore, streamStateStore = streamStore,
snowflakeConfiguration = snowflakeConfiguration, snowflakeConfiguration = snowflakeConfiguration,
snowflakeColumnUtils = snowflakeColumnUtils, catalog = catalog,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
) )
val aggregate = factory.create(key) val aggregate = factory.create(key)
assertNotNull(aggregate) assertNotNull(aggregate)
@@ -52,26 +63,33 @@ internal class SnowflakeAggregateFactoryTest {
@Test @Test
fun testCreatingAggregateWithStagingBuffer() { fun testCreatingAggregateWithStagingBuffer() {
val descriptor = Descriptor(namespace = "namespace", name = "name") val descriptor = Descriptor(namespace = "namespace", name = "name")
val directLoadTableExecutionConfig = val tableName =
DirectLoadTableExecutionConfig( TableName(
tableName = namespace = descriptor.namespace!!,
TableName( name = descriptor.name,
namespace = descriptor.namespace!!,
name = descriptor.name,
)
) )
val directLoadTableExecutionConfig = DirectLoadTableExecutionConfig(tableName = tableName)
val key = StoreKey(namespace = descriptor.namespace!!, name = descriptor.name) val key = StoreKey(namespace = descriptor.namespace!!, name = descriptor.name)
val streamStore = StreamStateStore<DirectLoadTableExecutionConfig>() val streamStore = StreamStateStore<DirectLoadTableExecutionConfig>()
streamStore.put(descriptor, directLoadTableExecutionConfig) streamStore.put(key, directLoadTableExecutionConfig)
val stream = mockk<DestinationStream>(relaxed = true)
val catalog = mockk<DestinationCatalog> { every { getStream(key) } returns stream }
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true) val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val snowflakeConfiguration = mockk<SnowflakeConfiguration>(relaxed = true) val snowflakeConfiguration =
val snowflakeColumnUtils = mockk<SnowflakeColumnUtils>(relaxed = true) mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns false }
val columnManager = SnowflakeColumnManager(snowflakeConfiguration)
val snowflakeRecordFormatter: SnowflakeRecordFormatter = SnowflakeSchemaRecordFormatter()
val factory = val factory =
SnowflakeAggregateFactory( SnowflakeAggregateFactory(
snowflakeClient = snowflakeClient, snowflakeClient = snowflakeClient,
streamStateStore = streamStore, streamStateStore = streamStore,
snowflakeConfiguration = snowflakeConfiguration, snowflakeConfiguration = snowflakeConfiguration,
snowflakeColumnUtils = snowflakeColumnUtils, catalog = catalog,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
) )
val aggregate = factory.create(key) val aggregate = factory.create(key)
assertNotNull(aggregate) assertNotNull(aggregate)

View File

@@ -1,23 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.db
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
internal class SnowflakeColumnNameGeneratorTest {
@Test
fun testGetColumnName() {
val column = "test-column"
val generator =
SnowflakeColumnNameGenerator(mockk { every { legacyRawTablesOnly } returns false })
val columnName = generator.getColumnName(column)
assertEquals(column.toSnowflakeCompatibleName(), columnName.displayName)
assertEquals(column.toSnowflakeCompatibleName(), columnName.canonicalName)
}
}

@@ -1,72 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.db
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
internal class SnowflakeFinalTableNameGeneratorTest {
@Test
fun testGetTableNameWithInternalNamespace() {
val configuration =
mockk<SnowflakeConfiguration> {
every { internalTableSchema } returns "test-internal-namespace"
every { legacyRawTablesOnly } returns true
}
val generator = SnowflakeFinalTableNameGenerator(config = configuration)
val streamName = "test-stream-name"
val streamNamespace = "test-stream-namespace"
val streamDescriptor =
mockk<DestinationStream.Descriptor> {
every { namespace } returns streamNamespace
every { name } returns streamName
}
val tableName = generator.getTableName(streamDescriptor)
assertEquals("test-stream-namespace_raw__stream_test-stream-name", tableName.name)
assertEquals("test-internal-namespace", tableName.namespace)
}
@Test
fun testGetTableNameWithNamespace() {
val configuration =
mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns false }
val generator = SnowflakeFinalTableNameGenerator(config = configuration)
val streamName = "test-stream-name"
val streamNamespace = "test-stream-namespace"
val streamDescriptor =
mockk<DestinationStream.Descriptor> {
every { namespace } returns streamNamespace
every { name } returns streamName
}
val tableName = generator.getTableName(streamDescriptor)
assertEquals("TEST-STREAM-NAME", tableName.name)
assertEquals("TEST-STREAM-NAMESPACE", tableName.namespace)
}
@Test
fun testGetTableNameWithDefaultNamespace() {
val defaultNamespace = "test-default-namespace"
val configuration =
mockk<SnowflakeConfiguration> {
every { schema } returns defaultNamespace
every { legacyRawTablesOnly } returns false
}
val generator = SnowflakeFinalTableNameGenerator(config = configuration)
val streamName = "test-stream-name"
val streamDescriptor =
mockk<DestinationStream.Descriptor> {
every { namespace } returns null
every { name } returns streamName
}
val tableName = generator.getTableName(streamDescriptor)
assertEquals("TEST-STREAM-NAME", tableName.name)
assertEquals("TEST-DEFAULT-NAMESPACE", tableName.namespace)
}
}

@@ -1,27 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.db
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.CsvSource
internal class SnowflakeNameGeneratorsTest {
@ParameterizedTest
@CsvSource(
value =
[
"test-name,TEST-NAME",
"1-test-name,1-TEST-NAME",
"test-name!!!,TEST-NAME!!!",
"test\${name,TEST__NAME",
"test\"name,TEST\"\"NAME",
]
)
fun testToSnowflakeCompatibleName(name: String, expected: String) {
assertEquals(expected, name.toSnowflakeCompatibleName())
}
}

@@ -1,358 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.sql
import com.fasterxml.jackson.databind.JsonNode
import io.airbyte.cdk.load.data.ArrayType
import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
import io.airbyte.cdk.load.data.BooleanType
import io.airbyte.cdk.load.data.DateType
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerType
import io.airbyte.cdk.load.data.NumberType
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.TimeTypeWithTimezone
import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
import io.airbyte.cdk.load.data.UnionType
import io.airbyte.cdk.load.data.UnknownType
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA
import io.airbyte.cdk.load.orchestration.db.ColumnNameGenerator
import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.integrations.destination.snowflake.db.SnowflakeColumnNameGenerator
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.mockk.every
import io.mockk.mockk
import kotlin.collections.LinkedHashMap
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.CsvSource
internal class SnowflakeColumnUtilsTest {
private lateinit var snowflakeConfiguration: SnowflakeConfiguration
private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
private lateinit var snowflakeColumnNameGenerator: SnowflakeColumnNameGenerator
@BeforeEach
fun setup() {
snowflakeConfiguration = mockk(relaxed = true)
snowflakeColumnNameGenerator =
mockk(relaxed = true) {
every { getColumnName(any()) } answers
{
val displayName =
if (snowflakeConfiguration.legacyRawTablesOnly) firstArg<String>()
else firstArg<String>().toSnowflakeCompatibleName()
val canonicalName =
if (snowflakeConfiguration.legacyRawTablesOnly) firstArg<String>()
else firstArg<String>().toSnowflakeCompatibleName()
ColumnNameGenerator.ColumnName(
displayName = displayName,
canonicalName = canonicalName,
)
}
}
snowflakeColumnUtils =
SnowflakeColumnUtils(snowflakeConfiguration, snowflakeColumnNameGenerator)
}
@Test
fun testDefaultColumns() {
val expectedDefaultColumns = DEFAULT_COLUMNS
assertEquals(expectedDefaultColumns, snowflakeColumnUtils.defaultColumns())
}
@Test
fun testDefaultRawColumns() {
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val expectedDefaultColumns = DEFAULT_COLUMNS + RAW_COLUMNS
assertEquals(expectedDefaultColumns, snowflakeColumnUtils.defaultColumns())
}
@Test
fun testGetFormattedDefaultColumnNames() {
val expectedDefaultColumnNames =
DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() }
val defaultColumnNames = snowflakeColumnUtils.getFormattedDefaultColumnNames()
assertEquals(expectedDefaultColumnNames, defaultColumnNames)
}
@Test
fun testGetFormattedDefaultColumnNamesQuoted() {
val expectedDefaultColumnNames =
DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName().quote() }
val defaultColumnNames = snowflakeColumnUtils.getFormattedDefaultColumnNames(true)
assertEquals(expectedDefaultColumnNames, defaultColumnNames)
}
@Test
fun testGetColumnName() {
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val columnNames = snowflakeColumnUtils.getColumnNames(columnNameMapping)
val expectedColumnNames =
(DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() } + listOf("actual"))
.joinToString(",") { it.quote() }
assertEquals(expectedColumnNames, columnNames)
}
@Test
fun testGetRawColumnName() {
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val columnNames = snowflakeColumnUtils.getColumnNames(columnNameMapping)
val expectedColumnNames =
(DEFAULT_COLUMNS.map { it.columnName } + RAW_COLUMNS.map { it.columnName })
.joinToString(",") { it.quote() }
assertEquals(expectedColumnNames, columnNames)
}
@Test
fun testGetRawFormattedColumnNames() {
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val schemaColumns =
mapOf(
"column_one" to FieldType(StringType, true),
"column_two" to FieldType(IntegerType, true),
"original" to FieldType(StringType, true),
CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true)
)
val expectedColumnNames =
DEFAULT_COLUMNS.map { it.columnName.quote() } +
RAW_COLUMNS.map { it.columnName.quote() }
val columnNames =
snowflakeColumnUtils.getFormattedColumnNames(
columns = schemaColumns,
columnNameMapping = columnNameMapping
)
assertEquals(expectedColumnNames.size, columnNames.size)
assertEquals(expectedColumnNames.sorted(), columnNames.sorted())
}
@Test
fun testGetFormattedColumnNames() {
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val schemaColumns =
mapOf(
"column_one" to FieldType(StringType, true),
"column_two" to FieldType(IntegerType, true),
"original" to FieldType(StringType, true),
CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true)
)
val expectedColumnNames =
listOf(
"actual",
"column_one",
"column_two",
CDC_DELETED_AT_COLUMN,
)
.map { it.quote() } +
DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName().quote() }
val columnNames =
snowflakeColumnUtils.getFormattedColumnNames(
columns = schemaColumns,
columnNameMapping = columnNameMapping
)
assertEquals(expectedColumnNames.size, columnNames.size)
assertEquals(expectedColumnNames.sorted(), columnNames.sorted())
}
@Test
fun testGetFormattedColumnNamesNoQuotes() {
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
val schemaColumns =
mapOf(
"column_one" to FieldType(StringType, true),
"column_two" to FieldType(IntegerType, true),
"original" to FieldType(StringType, true),
CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true)
)
val expectedColumnNames =
listOf(
"actual",
"column_one",
"column_two",
CDC_DELETED_AT_COLUMN,
) + DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() }
val columnNames =
snowflakeColumnUtils.getFormattedColumnNames(
columns = schemaColumns,
columnNameMapping = columnNameMapping,
quote = false
)
assertEquals(expectedColumnNames.size, columnNames.size)
assertEquals(expectedColumnNames.sorted(), columnNames.sorted())
}
@Test
fun testGeneratingRawTableColumnsAndTypesNoColumnMapping() {
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val columns =
snowflakeColumnUtils.columnsAndTypes(
columns = emptyMap(),
columnNameMapping = ColumnNameMapping(emptyMap())
)
assertEquals(DEFAULT_COLUMNS.size + RAW_COLUMNS.size, columns.size)
assertEquals(
"${SnowflakeDataType.VARIANT.typeName} $NOT_NULL",
columns.find { it.columnName == RAW_DATA_COLUMN.columnName }?.columnType
)
}
@Test
fun testGeneratingColumnsAndTypesNoColumnMapping() {
val columnName = "test-column"
val fieldType = FieldType(StringType, false)
val declaredColumns = mapOf(columnName to fieldType)
val columns =
snowflakeColumnUtils.columnsAndTypes(
columns = declaredColumns,
columnNameMapping = ColumnNameMapping(emptyMap())
)
assertEquals(DEFAULT_COLUMNS.size + 1, columns.size)
assertEquals(
"${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL",
columns.find { it.columnName == columnName }?.columnType
)
}
@Test
fun testGeneratingColumnsAndTypesWithColumnMapping() {
val columnName = "test-column"
val mappedColumnName = "mapped-column-name"
val fieldType = FieldType(StringType, false)
val declaredColumns = mapOf(columnName to fieldType)
val columnNameMapping = ColumnNameMapping(mapOf(columnName to mappedColumnName))
val columns =
snowflakeColumnUtils.columnsAndTypes(
columns = declaredColumns,
columnNameMapping = columnNameMapping
)
assertEquals(DEFAULT_COLUMNS.size + 1, columns.size)
assertEquals(
"${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL",
columns.find { it.columnName == mappedColumnName }?.columnType
)
}
@Test
fun testToDialectType() {
assertEquals(
SnowflakeDataType.BOOLEAN.typeName,
snowflakeColumnUtils.toDialectType(BooleanType)
)
assertEquals(SnowflakeDataType.DATE.typeName, snowflakeColumnUtils.toDialectType(DateType))
assertEquals(
SnowflakeDataType.NUMBER.typeName,
snowflakeColumnUtils.toDialectType(IntegerType)
)
assertEquals(
SnowflakeDataType.FLOAT.typeName,
snowflakeColumnUtils.toDialectType(NumberType)
)
assertEquals(
SnowflakeDataType.VARCHAR.typeName,
snowflakeColumnUtils.toDialectType(StringType)
)
assertEquals(
SnowflakeDataType.VARCHAR.typeName,
snowflakeColumnUtils.toDialectType(TimeTypeWithTimezone)
)
assertEquals(
SnowflakeDataType.TIME.typeName,
snowflakeColumnUtils.toDialectType(TimeTypeWithoutTimezone)
)
assertEquals(
SnowflakeDataType.TIMESTAMP_TZ.typeName,
snowflakeColumnUtils.toDialectType(TimestampTypeWithTimezone)
)
assertEquals(
SnowflakeDataType.TIMESTAMP_NTZ.typeName,
snowflakeColumnUtils.toDialectType(TimestampTypeWithoutTimezone)
)
assertEquals(
SnowflakeDataType.ARRAY.typeName,
snowflakeColumnUtils.toDialectType(ArrayType(items = FieldType(StringType, false)))
)
assertEquals(
SnowflakeDataType.ARRAY.typeName,
snowflakeColumnUtils.toDialectType(ArrayTypeWithoutSchema)
)
assertEquals(
SnowflakeDataType.OBJECT.typeName,
snowflakeColumnUtils.toDialectType(
ObjectType(
properties = LinkedHashMap(),
additionalProperties = false,
)
)
)
assertEquals(
SnowflakeDataType.OBJECT.typeName,
snowflakeColumnUtils.toDialectType(ObjectTypeWithEmptySchema)
)
assertEquals(
SnowflakeDataType.OBJECT.typeName,
snowflakeColumnUtils.toDialectType(ObjectTypeWithoutSchema)
)
assertEquals(
SnowflakeDataType.VARIANT.typeName,
snowflakeColumnUtils.toDialectType(
UnionType(
options = setOf(StringType),
isLegacyUnion = true,
)
)
)
assertEquals(
SnowflakeDataType.VARIANT.typeName,
snowflakeColumnUtils.toDialectType(
UnionType(
options = emptySet(),
isLegacyUnion = false,
)
)
)
assertEquals(
SnowflakeDataType.VARIANT.typeName,
snowflakeColumnUtils.toDialectType(UnknownType(schema = mockk<JsonNode>()))
)
}
@ParameterizedTest
@CsvSource(
value =
[
"$COLUMN_NAME_DATA, true, \"$COLUMN_NAME_DATA\"",
"some-other_Column, true, \"SOME-OTHER_COLUMN\"",
"$COLUMN_NAME_DATA, false, $COLUMN_NAME_DATA",
"some-other_Column, false, SOME-OTHER_COLUMN",
"$COLUMN_NAME_DATA, true, \"$COLUMN_NAME_DATA\"",
"some-other_Column, true, \"SOME-OTHER_COLUMN\"",
]
)
fun testFormatColumnName(columnName: String, quote: Boolean, expectedFormattedName: String) {
assertEquals(
expectedFormattedName,
snowflakeColumnUtils.formatColumnName(columnName, quote)
)
}
}

@@ -1,97 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination.snowflake.sql
import io.airbyte.cdk.load.table.TableName
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
internal class SnowflakeSqlNameUtilsTest {
private lateinit var snowflakeConfiguration: SnowflakeConfiguration
private lateinit var snowflakeSqlNameUtils: SnowflakeSqlNameUtils
@BeforeEach
fun setUp() {
snowflakeConfiguration = mockk(relaxed = true)
snowflakeSqlNameUtils =
SnowflakeSqlNameUtils(snowflakeConfiguration = snowflakeConfiguration)
}
@Test
fun testFullyQualifiedName() {
val databaseName = "test-database"
val namespace = "test-namespace"
val name = "test=name"
val tableName = TableName(namespace = namespace, name = name)
every { snowflakeConfiguration.database } returns databaseName
val expectedName =
snowflakeSqlNameUtils.combineParts(
listOf(
databaseName.toSnowflakeCompatibleName(),
tableName.namespace,
tableName.name
)
)
val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedName(tableName)
assertEquals(expectedName, fullyQualifiedName)
}
@Test
fun testFullyQualifiedNamespace() {
val databaseName = "test-database"
val namespace = "test-namespace"
every { snowflakeConfiguration.database } returns databaseName
val fullyQualifiedNamespace = snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)
assertEquals("\"TEST-DATABASE\".\"test-namespace\"", fullyQualifiedNamespace)
}
@Test
fun testFullyQualifiedStageName() {
val databaseName = "test-database"
val namespace = "test-namespace"
val name = "test=name"
val tableName = TableName(namespace = namespace, name = name)
every { snowflakeConfiguration.database } returns databaseName
val expectedName =
snowflakeSqlNameUtils.combineParts(
listOf(
databaseName.toSnowflakeCompatibleName(),
namespace,
"$STAGE_NAME_PREFIX$name"
)
)
val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName)
assertEquals(expectedName, fullyQualifiedName)
}
@Test
fun testFullyQualifiedStageNameWithEscape() {
val databaseName = "test-database"
val namespace = "test-namespace"
val name = "test=\"\"\'name"
val tableName = TableName(namespace = namespace, name = name)
every { snowflakeConfiguration.database } returns databaseName
val expectedName =
snowflakeSqlNameUtils.combineParts(
listOf(
databaseName.toSnowflakeCompatibleName(),
namespace,
"$STAGE_NAME_PREFIX${sqlEscape(name)}"
)
)
val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true)
assertEquals(expectedName, fullyQualifiedName)
}
}

@@ -6,20 +6,21 @@ package io.airbyte.integrations.destination.snowflake.write
import io.airbyte.cdk.SystemErrorException import io.airbyte.cdk.SystemErrorException
import io.airbyte.cdk.load.command.Append import io.airbyte.cdk.load.command.Append
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationStream import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.orchestration.db.DatabaseInitialStatusGatherer import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.cdk.load.orchestration.db.TableNames import io.airbyte.cdk.load.schema.model.StreamTableSchema
import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadInitialStatus import io.airbyte.cdk.load.schema.model.TableNames
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendStreamLoader import io.airbyte.cdk.load.table.DatabaseInitialStatusGatherer
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendTruncateStreamLoader import io.airbyte.cdk.load.table.TempTableNameGenerator
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableStatus import io.airbyte.cdk.load.table.directload.DirectLoadInitialStatus
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendStreamLoader
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableNameInfo import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendTruncateStreamLoader
import io.airbyte.cdk.load.table.ColumnNameMapping import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.table.TableName import io.airbyte.cdk.load.table.directload.DirectLoadTableStatus
import io.airbyte.cdk.load.write.StreamStateStore
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.mockk.coEvery import io.mockk.coEvery
import io.mockk.coVerify import io.mockk.coVerify
import io.mockk.every import io.mockk.every
@@ -34,55 +35,93 @@ internal class SnowflakeWriterTest {
@Test @Test
fun testSetup() { fun testSetup() {
val tableName = TableName(namespace = "test-namespace", name = "test-name") val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName) val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
val stream = mockk<DestinationStream>() val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val tableInfo = val stream =
TableNameInfo( mockk<DestinationStream> {
tableNames = tableNames, every { tableSchema } returns
columnNameMapping = ColumnNameMapping(emptyMap()) StreamTableSchema(
) tableNames = tableNames,
val catalog = TableCatalog(mapOf(stream to tableInfo)) columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
every { importType } returns Append
}
val catalog = DestinationCatalog(listOf(stream))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true) val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer = val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> { mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns emptyMap() coEvery { gatherInitialStatus() } returns
mapOf(
stream to
DirectLoadInitialStatus(
realTable = DirectLoadTableStatus(false),
tempTable = null
)
)
} }
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer = val writer =
SnowflakeWriter( SnowflakeWriter(
names = catalog, catalog = catalog,
stateGatherer = stateGatherer, stateGatherer = stateGatherer,
streamStateStore = mockk(), streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient, snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(), tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(relaxed = true), snowflakeConfiguration =
mockk(relaxed = true) {
every { internalTableSchema } returns "internal_schema"
},
) )
runBlocking { writer.setup() } runBlocking { writer.setup() }
coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) } coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) }
coVerify(exactly = 1) { stateGatherer.gatherInitialStatus(catalog) } coVerify(exactly = 1) { stateGatherer.gatherInitialStatus() }
} }
@Test @Test
fun testCreateStreamLoaderFirstGeneration() { fun testCreateStreamLoaderFirstGeneration() {
val tableName = TableName(namespace = "test-namespace", name = "test-name") val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName) val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream = val stream =
mockk<DestinationStream> { mockk<DestinationStream> {
every { minimumGenerationId } returns 0L every { minimumGenerationId } returns 0L
every { generationId } returns 0L every { generationId } returns 0L
every { importType } returns Append every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
} }
val tableInfo = val catalog = DestinationCatalog(listOf(stream))
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true) val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer = val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> { mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns coEvery { gatherInitialStatus() } returns
mapOf( mapOf(
stream to stream to
DirectLoadInitialStatus( DirectLoadInitialStatus(
@@ -93,14 +132,18 @@ internal class SnowflakeWriterTest {
} }
val tempTableNameGenerator = val tempTableNameGenerator =
mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } } mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } }
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer = val writer =
SnowflakeWriter( SnowflakeWriter(
names = catalog, catalog = catalog,
stateGatherer = stateGatherer, stateGatherer = stateGatherer,
streamStateStore = mockk(), streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient, snowflakeClient = snowflakeClient,
tempTableNameGenerator = tempTableNameGenerator, tempTableNameGenerator = tempTableNameGenerator,
snowflakeConfiguration = mockk(relaxed = true), snowflakeConfiguration =
mockk(relaxed = true) {
every { internalTableSchema } returns "internal_schema"
},
) )
runBlocking { runBlocking {
@@ -113,23 +156,35 @@ internal class SnowflakeWriterTest {
@Test @Test
fun testCreateStreamLoaderNotFirstGeneration() { fun testCreateStreamLoaderNotFirstGeneration() {
val tableName = TableName(namespace = "test-namespace", name = "test-name") val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName) val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream = val stream =
mockk<DestinationStream> { mockk<DestinationStream> {
every { minimumGenerationId } returns 1L every { minimumGenerationId } returns 1L
every { generationId } returns 1L every { generationId } returns 1L
every { importType } returns Append every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
} }
val tableInfo = val catalog = DestinationCatalog(listOf(stream))
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true) val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer = val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> { mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns coEvery { gatherInitialStatus() } returns
mapOf( mapOf(
stream to stream to
DirectLoadInitialStatus( DirectLoadInitialStatus(
@@ -140,14 +195,18 @@ internal class SnowflakeWriterTest {
} }
val tempTableNameGenerator = val tempTableNameGenerator =
mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } } mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } }
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer = val writer =
SnowflakeWriter( SnowflakeWriter(
names = catalog, catalog = catalog,
stateGatherer = stateGatherer, stateGatherer = stateGatherer,
streamStateStore = mockk(), streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient, snowflakeClient = snowflakeClient,
tempTableNameGenerator = tempTableNameGenerator, tempTableNameGenerator = tempTableNameGenerator,
snowflakeConfiguration = mockk(relaxed = true), snowflakeConfiguration =
mockk(relaxed = true) {
every { internalTableSchema } returns "internal_schema"
},
) )
runBlocking { runBlocking {
@@ -160,22 +219,35 @@ internal class SnowflakeWriterTest {
@Test @Test
fun testCreateStreamLoaderHybrid() { fun testCreateStreamLoaderHybrid() {
val tableName = TableName(namespace = "test-namespace", name = "test-name") val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName) val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream = val stream =
mockk<DestinationStream> { mockk<DestinationStream> {
every { minimumGenerationId } returns 1L every { minimumGenerationId } returns 1L
every { generationId } returns 2L every { generationId } returns 2L
every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
} }
val tableInfo = val catalog = DestinationCatalog(listOf(stream))
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true) val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer = val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> { mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns coEvery { gatherInitialStatus() } returns
mapOf( mapOf(
stream to stream to
DirectLoadInitialStatus( DirectLoadInitialStatus(
@@ -186,14 +258,18 @@ internal class SnowflakeWriterTest {
} }
val tempTableNameGenerator = val tempTableNameGenerator =
mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } } mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } }
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer = val writer =
SnowflakeWriter( SnowflakeWriter(
names = catalog, catalog = catalog,
stateGatherer = stateGatherer, stateGatherer = stateGatherer,
streamStateStore = mockk(), streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient, snowflakeClient = snowflakeClient,
tempTableNameGenerator = tempTableNameGenerator, tempTableNameGenerator = tempTableNameGenerator,
snowflakeConfiguration = mockk(relaxed = true), snowflakeConfiguration =
mockk(relaxed = true) {
every { internalTableSchema } returns "internal_schema"
},
) )
runBlocking { runBlocking {
@@ -203,169 +279,126 @@ internal class SnowflakeWriterTest {
} }
@Test @Test
fun testSetupWithNamespaceCreationFailure() { fun testCreateStreamLoaderNamespaceLegacy() {
val tableName = TableName(namespace = "test-namespace", name = "test-name") val namespace = "test-namespace"
val tableNames = TableNames(rawTableName = null, finalTableName = tableName) val name = "test-name"
val stream = mockk<DestinationStream>() val tableName = TableName(namespace = namespace, name = name)
val tableInfo = val tempTableName = TableName(namespace = namespace, name = "${name}-temp")
TableNameInfo( val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val snowflakeClient = mockk<SnowflakeAirbyteClient>()
val stateGatherer = mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>>()
val writer =
SnowflakeWriter(
names = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(),
)
// Simulate network failure during namespace creation
coEvery {
snowflakeClient.createNamespace(tableName.namespace.toSnowflakeCompatibleName())
} throws RuntimeException("Network connection failed")
assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } }
}
@Test
fun testSetupWithInitialStatusGatheringFailure() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val stream = mockk<DestinationStream>()
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer = mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>>()
val writer =
SnowflakeWriter(
names = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(),
)
// Simulate failure while gathering initial status
coEvery { stateGatherer.gatherInitialStatus(catalog) } throws
RuntimeException("Failed to query table status")
assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } }
// Verify namespace creation was still attempted
coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) }
}
@Test
fun testCreateStreamLoaderWithMissingInitialStatus() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val stream = val stream =
mockk<DestinationStream> { mockk<DestinationStream> {
every { minimumGenerationId } returns 0L every { minimumGenerationId } returns 0L
every { generationId } returns 0L every { generationId } returns 0L
every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
} }
val missingStream = val catalog = DestinationCatalog(listOf(stream))
mockk<DestinationStream> {
every { minimumGenerationId } returns 0L
every { generationId } returns 0L
}
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true) val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer = val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> { mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns coEvery { gatherInitialStatus() } returns
mapOf( mapOf(
stream to stream to
DirectLoadInitialStatus( DirectLoadInitialStatus(
realTable = DirectLoadTableStatus(false), realTable = DirectLoadTableStatus(false),
tempTable = null, tempTable = null
) )
) )
} }
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer = val writer =
SnowflakeWriter( SnowflakeWriter(
names = catalog, catalog = catalog,
stateGatherer = stateGatherer, stateGatherer = stateGatherer,
streamStateStore = mockk(), streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient, snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(), tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(relaxed = true), snowflakeConfiguration =
mockk(relaxed = true) {
every { legacyRawTablesOnly } returns true
every { internalTableSchema } returns "internal_schema"
},
) )
runBlocking { runBlocking { writer.setup() }
writer.setup()
// Try to create loader for a stream that wasn't in initial status coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) }
assertThrows(NullPointerException::class.java) { }
writer.createStreamLoader(missingStream)
@Test
fun testCreateStreamLoaderNamespaceNonLegacy() {
val namespace = "test-namespace"
val name = "test-name"
val tableName = TableName(namespace = namespace, name = name)
val tempTableName = TableName(namespace = namespace, name = "${name}-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream =
mockk<DestinationStream> {
every { minimumGenerationId } returns 0L
every { generationId } returns 0L
every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
} }
} val catalog = DestinationCatalog(listOf(stream))
} val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer =
@Test mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
fun testCreateStreamLoaderWithNullFinalTableName() { coEvery { gatherInitialStatus() } returns
// TableNames constructor throws IllegalStateException when both names are null mapOf(
assertThrows(IllegalStateException::class.java) { stream to
TableNames(rawTableName = null, finalTableName = null) DirectLoadInitialStatus(
} realTable = DirectLoadTableStatus(false),
} tempTable = null
)
@Test )
fun testSetupWithMultipleNamespaceFailuresPartial() { }
val tableName1 = TableName(namespace = "namespace1", name = "table1") val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val tableName2 = TableName(namespace = "namespace2", name = "table2")
val tableNames1 = TableNames(rawTableName = null, finalTableName = tableName1)
val tableNames2 = TableNames(rawTableName = null, finalTableName = tableName2)
val stream1 = mockk<DestinationStream>()
val stream2 = mockk<DestinationStream>()
val tableInfo1 =
TableNameInfo(
tableNames = tableNames1,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val tableInfo2 =
TableNameInfo(
tableNames = tableNames2,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream1 to tableInfo1, stream2 to tableInfo2))
val snowflakeClient = mockk<SnowflakeAirbyteClient>()
val stateGatherer = mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>>()
val writer = val writer =
SnowflakeWriter( SnowflakeWriter(
names = catalog, catalog = catalog,
stateGatherer = stateGatherer, stateGatherer = stateGatherer,
streamStateStore = mockk(), streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient, snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(), tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(), snowflakeConfiguration =
mockk(relaxed = true) {
every { legacyRawTablesOnly } returns false
every { internalTableSchema } returns "internal_schema"
},
) )
// First namespace succeeds, second fails (namespaces are uppercased by runBlocking { writer.setup() }
// toSnowflakeCompatibleName)
coEvery { snowflakeClient.createNamespace("namespace1") } returns Unit
coEvery { snowflakeClient.createNamespace("namespace2") } throws
RuntimeException("Connection timeout")
assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } } coVerify(exactly = 1) { snowflakeClient.createNamespace(namespace) }
// Verify both namespace creations were attempted
coVerify(exactly = 1) { snowflakeClient.createNamespace("namespace1") }
coVerify(exactly = 1) { snowflakeClient.createNamespace("namespace2") }
} }
} }
@@ -4,218 +4,220 @@
package io.airbyte.integrations.destination.snowflake.write.load package io.airbyte.integrations.destination.snowflake.write.load
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.data.AirbyteValue import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerValue import io.airbyte.cdk.load.data.IntegerValue
import io.airbyte.cdk.load.data.NullValue import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.StringValue import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.message.Meta import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.table.TableName import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.mockk.coVerify import io.mockk.coVerify
import io.mockk.every import io.mockk.every
import io.mockk.mockk import io.mockk.mockk
import java.io.BufferedReader import java.io.BufferedReader
import java.io.File
import java.io.InputStreamReader import java.io.InputStreamReader
import java.util.zip.GZIPInputStream import java.util.zip.GZIPInputStream
import kotlin.io.path.exists import kotlin.io.path.exists
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNotNull
import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
internal class SnowflakeInsertBufferTest { internal class SnowflakeInsertBufferTest {
private lateinit var snowflakeConfiguration: SnowflakeConfiguration private lateinit var snowflakeConfiguration: SnowflakeConfiguration
private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils private lateinit var columnManager: SnowflakeColumnManager
private lateinit var columnSchema: ColumnSchema
private lateinit var snowflakeRecordFormatter: SnowflakeRecordFormatter
@BeforeEach @BeforeEach
fun setUp() { fun setUp() {
snowflakeConfiguration = mockk(relaxed = true) snowflakeConfiguration = mockk(relaxed = true)
snowflakeColumnUtils = mockk(relaxed = true) snowflakeRecordFormatter = SnowflakeSchemaRecordFormatter()
columnManager =
mockk(relaxed = true) {
every { getMetaColumns() } returns
linkedMapOf(
"_AIRBYTE_RAW_ID" to ColumnType("VARCHAR", false),
"_AIRBYTE_EXTRACTED_AT" to ColumnType("TIMESTAMP_TZ", false),
"_AIRBYTE_META" to ColumnType("VARIANT", false),
"_AIRBYTE_GENERATION_ID" to ColumnType("NUMBER", true)
)
every { getTableColumnNames(any()) } returns
listOf(
"_AIRBYTE_RAW_ID",
"_AIRBYTE_EXTRACTED_AT",
"_AIRBYTE_META",
"_AIRBYTE_GENERATION_ID",
"columnName"
)
}
} }
@Test @Test
fun testAccumulate() { fun testAccumulate() {
val tableName = mockk<TableName>(relaxed = true) val tableName = TableName(namespace = "test", name = "table")
val column = "columnName" val column = "columnName"
val columns = linkedMapOf(column to "NUMBER(38,0)") columnSchema =
ColumnSchema(
inputToFinalColumnNames = mapOf(column to column.uppercase()),
finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
inputSchema = mapOf(column to FieldType(StringType, nullable = true))
)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true) val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord(column) val record = createRecord(column)
val buffer = val buffer =
SnowflakeInsertBuffer( SnowflakeInsertBuffer(
tableName = tableName, tableName = tableName,
columns = columns,
snowflakeClient = snowflakeAirbyteClient, snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration, snowflakeConfiguration = snowflakeConfiguration,
flushLimit = 1, columnSchema = columnSchema,
snowflakeColumnUtils = snowflakeColumnUtils, columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
) )
assertEquals(0, buffer.recordCount)
buffer.accumulate(record) runBlocking { buffer.accumulate(record) }
assertEquals(true, buffer.csvFilePath?.exists())
assertEquals(1, buffer.recordCount) assertEquals(1, buffer.recordCount)
} }
@Test @Test
fun testAccumulateRaw() { fun testFlushToStaging() {
val tableName = mockk<TableName>(relaxed = true) val tableName = TableName(namespace = "test", name = "table")
val column = "columnName" val column = "columnName"
val columns = linkedMapOf(column to "NUMBER(38,0)") columnSchema =
ColumnSchema(
inputToFinalColumnNames = mapOf(column to column.uppercase()),
finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
inputSchema = mapOf(column to FieldType(StringType, nullable = true))
)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true) val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord(column) val record = createRecord(column)
val buffer = val buffer =
SnowflakeInsertBuffer( SnowflakeInsertBuffer(
tableName = tableName, tableName = tableName,
columns = columns,
snowflakeClient = snowflakeAirbyteClient, snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration, snowflakeConfiguration = snowflakeConfiguration,
columnSchema = columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
flushLimit = 1, flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
) )
val expectedColumnNames =
every { snowflakeConfiguration.legacyRawTablesOnly } returns true listOf(
"_AIRBYTE_RAW_ID",
buffer.accumulate(record) "_AIRBYTE_EXTRACTED_AT",
"_AIRBYTE_META",
assertEquals(true, buffer.csvFilePath?.exists()) "_AIRBYTE_GENERATION_ID",
assertEquals(1, buffer.recordCount) "columnName"
}
@Test
fun testFlush() {
val tableName = mockk<TableName>(relaxed = true)
val column = "columnName"
val columns = linkedMapOf(column to "NUMBER(38,0)")
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord(column)
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = columns,
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
) )
runBlocking { runBlocking {
buffer.accumulate(record) buffer.accumulate(record)
buffer.flush() buffer.flush()
} coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
coVerify(exactly = 1) {
coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) } snowflakeAirbyteClient.copyFromStage(tableName, any(), expectedColumnNames)
coVerify(exactly = 1) { }
snowflakeAirbyteClient.copyFromStage(
tableName,
match { it.endsWith("$CSV_FILE_EXTENSION$FILE_SUFFIX") }
)
} }
} }
@Test @Test
fun testFlushRaw() { fun testFlushToNoStaging() {
val tableName = mockk<TableName>(relaxed = true) val tableName = TableName(namespace = "test", name = "table")
val column = "columnName" val column = "columnName"
val columns = linkedMapOf(column to "NUMBER(38,0)") columnSchema =
ColumnSchema(
inputToFinalColumnNames = mapOf(column to column.uppercase()),
finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
inputSchema = mapOf(column to FieldType(StringType, nullable = true))
)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val record = createRecord(column)
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
columnSchema = columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
flushLimit = 1,
)
val expectedColumnNames =
listOf(
"_AIRBYTE_RAW_ID",
"_AIRBYTE_EXTRACTED_AT",
"_AIRBYTE_META",
"_AIRBYTE_GENERATION_ID",
"columnName"
)
runBlocking {
buffer.accumulate(record)
buffer.flush()
// In legacy raw mode, it still uses staging
coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
coVerify(exactly = 1) {
snowflakeAirbyteClient.copyFromStage(tableName, any(), expectedColumnNames)
}
}
}
@Test
fun testFileCreation() {
val tableName = TableName(namespace = "test", name = "table")
val column = "columnName"
columnSchema =
ColumnSchema(
inputToFinalColumnNames = mapOf(column to column.uppercase()),
finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
inputSchema = mapOf(column to FieldType(StringType, nullable = true))
)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true) val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord(column) val record = createRecord(column)
val buffer = val buffer =
SnowflakeInsertBuffer( SnowflakeInsertBuffer(
tableName = tableName, tableName = tableName,
columns = columns,
snowflakeClient = snowflakeAirbyteClient, snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration, snowflakeConfiguration = snowflakeConfiguration,
columnSchema = columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
flushLimit = 1, flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
) )
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
runBlocking { runBlocking {
buffer.accumulate(record) buffer.accumulate(record)
buffer.flush() // The csvFilePath is internal, we can access it for testing
} val filepath = buffer.csvFilePath
assertNotNull(filepath)
coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) } val file = filepath!!.toFile()
coVerify(exactly = 1) { assert(file.exists())
snowflakeAirbyteClient.copyFromStage( // Close the writer to ensure all data is flushed
tableName,
match { it.endsWith("$CSV_FILE_EXTENSION$FILE_SUFFIX") }
)
}
}
@Test
fun testMissingFields() {
val tableName = mockk<TableName>(relaxed = true)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord("COLUMN1")
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = linkedMapOf("COLUMN1" to "NUMBER(38,0)", "COLUMN2" to "NUMBER(38,0)"),
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
)
runBlocking {
buffer.accumulate(record)
buffer.csvWriter?.flush()
buffer.csvWriter?.close() buffer.csvWriter?.close()
assertEquals( val lines = mutableListOf<String>()
"test-value$CSV_FIELD_SEPARATOR$CSV_LINE_DELIMITER", GZIPInputStream(file.inputStream()).use { gzip ->
readFromCsvFile(buffer.csvFilePath!!.toFile()) BufferedReader(InputStreamReader(gzip)).use { bufferedReader ->
) bufferedReader.forEachLine { line -> lines.add(line) }
}
}
assertEquals(1, lines.size)
file.delete()
} }
} }
@Test private fun createRecord(column: String): Map<String, AirbyteValue> {
fun testMissingFieldsRaw() { return mapOf(
val tableName = mockk<TableName>(relaxed = true) column to IntegerValue(value = 42),
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true) Meta.COLUMN_NAME_AB_GENERATION_ID to NullValue,
val record = createRecord("COLUMN1") Meta.COLUMN_NAME_AB_RAW_ID to StringValue("raw-id-1"),
val buffer = Meta.COLUMN_NAME_AB_EXTRACTED_AT to IntegerValue(1234567890),
SnowflakeInsertBuffer( Meta.COLUMN_NAME_AB_META to StringValue("meta-data-foo"),
tableName = tableName,
columns = linkedMapOf("COLUMN1" to "NUMBER(38,0)", "COLUMN2" to "NUMBER(38,0)"),
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
)
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
runBlocking {
buffer.accumulate(record)
buffer.csvWriter?.flush()
buffer.csvWriter?.close()
assertEquals(
"test-value$CSV_FIELD_SEPARATOR$CSV_LINE_DELIMITER",
readFromCsvFile(buffer.csvFilePath!!.toFile())
)
}
}
private fun readFromCsvFile(file: File) =
GZIPInputStream(file.inputStream()).use { input ->
val reader = BufferedReader(InputStreamReader(input))
reader.readText()
}
private fun createRecord(columnName: String) =
mapOf(
columnName to AirbyteValue.from("test-value"),
Meta.COLUMN_NAME_AB_EXTRACTED_AT to IntegerValue(System.currentTimeMillis()),
Meta.COLUMN_NAME_AB_RAW_ID to StringValue("raw-id"),
Meta.COLUMN_NAME_AB_GENERATION_ID to IntegerValue(1223),
Meta.COLUMN_NAME_AB_META to StringValue("{\"changes\":[],\"syncId\":43}"),
"${columnName}Null" to NullValue
) )
}
} }
@@ -15,12 +15,8 @@ import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.mockk.every
import io.mockk.mockk
import kotlin.collections.plus
import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
private val AIRBYTE_COLUMN_TYPES_MAP = private val AIRBYTE_COLUMN_TYPES_MAP =
@@ -58,28 +54,16 @@ private fun createExpected(
internal class SnowflakeRawRecordFormatterTest { internal class SnowflakeRawRecordFormatterTest {
private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
@BeforeEach
fun setup() {
snowflakeColumnUtils = mockk {
every { getFormattedDefaultColumnNames(any()) } returns
AIRBYTE_COLUMN_TYPES_MAP.keys.toList()
}
}
@Test @Test
fun testFormatting() { fun testFormatting() {
val columnName = "test-column-name" val columnName = "test-column-name"
val columnValue = "test-column-value" val columnValue = "test-column-value"
val columns = AIRBYTE_COLUMN_TYPES_MAP val columns = AIRBYTE_COLUMN_TYPES_MAP
val record = createRecord(columnName = columnName, columnValue = columnValue) val record = createRecord(columnName = columnName, columnValue = columnValue)
val formatter = val formatter = SnowflakeRawRecordFormatter()
SnowflakeRawRecordFormatter( // RawRecordFormatter doesn't use columnSchema but still needs one per interface
columns = AIRBYTE_COLUMN_TYPES_MAP, val dummyColumnSchema = ColumnSchema(emptyMap(), emptyMap(), emptyMap())
snowflakeColumnUtils = snowflakeColumnUtils val formattedValue = formatter.format(record, dummyColumnSchema)
)
val formattedValue = formatter.format(record)
val expectedValue = val expectedValue =
createExpected( createExpected(
record = record, record = record,
@@ -93,33 +77,28 @@ internal class SnowflakeRawRecordFormatterTest {
fun testFormattingMigratedFromPreviousVersion() { fun testFormattingMigratedFromPreviousVersion() {
val columnName = "test-column-name" val columnName = "test-column-name"
val columnValue = "test-column-value" val columnValue = "test-column-value"
val columnsMap =
linkedMapOf(
COLUMN_NAME_AB_EXTRACTED_AT to "TIMESTAMP_TZ(9)",
COLUMN_NAME_AB_LOADED_AT to "TIMESTAMP_TZ(9)",
COLUMN_NAME_AB_META to "VARIANT",
COLUMN_NAME_DATA to "VARIANT",
COLUMN_NAME_AB_RAW_ID to "VARCHAR(16777216)",
COLUMN_NAME_AB_GENERATION_ID to "NUMBER(38,0)",
)
val record = createRecord(columnName = columnName, columnValue = columnValue) val record = createRecord(columnName = columnName, columnValue = columnValue)
val formatter = val formatter = SnowflakeRawRecordFormatter()
SnowflakeRawRecordFormatter( // RawRecordFormatter doesn't use columnSchema but still needs one per interface
columns = columnsMap, val dummyColumnSchema = ColumnSchema(emptyMap(), emptyMap(), emptyMap())
snowflakeColumnUtils = snowflakeColumnUtils val formattedValue = formatter.format(record, dummyColumnSchema)
)
val formattedValue = formatter.format(record) // The formatter outputs in a fixed order regardless of input column order:
// 1. AB_RAW_ID
// 2. AB_EXTRACTED_AT
// 3. AB_META
// 4. AB_GENERATION_ID
// 5. AB_LOADED_AT
// 6. DATA (JSON with remaining columns)
val expectedValue = val expectedValue =
createExpected( listOf(
record = record, record[COLUMN_NAME_AB_RAW_ID]!!.toCsvValue(),
columns = columnsMap, record[COLUMN_NAME_AB_EXTRACTED_AT]!!.toCsvValue(),
airbyteColumns = columnsMap.keys.toList(), record[COLUMN_NAME_AB_META]!!.toCsvValue(),
) record[COLUMN_NAME_AB_GENERATION_ID]!!.toCsvValue(),
.toMutableList() record[COLUMN_NAME_AB_LOADED_AT]!!.toCsvValue(),
expectedValue.add( "{\"$columnName\":\"$columnValue\"}"
columnsMap.keys.indexOf(COLUMN_NAME_DATA), )
"{\"$columnName\":\"$columnValue\"}"
)
assertEquals(expectedValue, formattedValue) assertEquals(expectedValue, formattedValue)
} }
} }
@@ -4,66 +4,58 @@
package io.airbyte.integrations.destination.snowflake.write.load package io.airbyte.integrations.destination.snowflake.write.load
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.data.AirbyteValue import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerValue import io.airbyte.cdk.load.data.IntegerValue
import io.airbyte.cdk.load.data.NullValue import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.StringValue import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.data.csv.toCsvValue import io.airbyte.cdk.load.data.csv.toCsvValue
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.mockk.every
import io.mockk.mockk
import java.util.AbstractMap
import kotlin.collections.component1
import kotlin.collections.component2
import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
private val AIRBYTE_COLUMN_TYPES_MAP =
linkedMapOf(
COLUMN_NAME_AB_RAW_ID to "VARCHAR(16777216)",
COLUMN_NAME_AB_EXTRACTED_AT to "TIMESTAMP_TZ(9)",
COLUMN_NAME_AB_META to "VARIANT",
COLUMN_NAME_AB_GENERATION_ID to "NUMBER(38,0)",
)
.mapKeys { it.key.toSnowflakeCompatibleName() }
internal class SnowflakeSchemaRecordFormatterTest { internal class SnowflakeSchemaRecordFormatterTest {
private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils private fun createColumnSchema(userColumns: Map<String, String>): ColumnSchema {
val finalSchema = linkedMapOf<String, ColumnType>()
val inputToFinalColumnNames = mutableMapOf<String, String>()
val inputSchema = mutableMapOf<String, FieldType>()
@BeforeEach // Add user columns
fun setup() { userColumns.forEach { (name, type) ->
snowflakeColumnUtils = mockk { val finalName = name.toSnowflakeCompatibleName()
every { getFormattedDefaultColumnNames(any()) } returns finalSchema[finalName] = ColumnType(type, true)
AIRBYTE_COLUMN_TYPES_MAP.keys.toList() inputToFinalColumnNames[name] = finalName
inputSchema[name] = FieldType(StringType, nullable = true)
} }
return ColumnSchema(
inputToFinalColumnNames = inputToFinalColumnNames,
finalSchema = finalSchema,
inputSchema = inputSchema
)
} }
@Test @Test
fun testFormatting() { fun testFormatting() {
val columnName = "test-column-name" val columnName = "test-column-name"
val columnValue = "test-column-value" val columnValue = "test-column-value"
val columns = val userColumns = mapOf(columnName to "VARCHAR(16777216)")
(AIRBYTE_COLUMN_TYPES_MAP + linkedMapOf(columnName to "VARCHAR(16777216)")).mapKeys { val columnSchema = createColumnSchema(userColumns)
it.key.toSnowflakeCompatibleName()
}
val record = createRecord(columnName, columnValue) val record = createRecord(columnName, columnValue)
val formatter = val formatter = SnowflakeSchemaRecordFormatter()
SnowflakeSchemaRecordFormatter( val formattedValue = formatter.format(record, columnSchema)
columns = columns as LinkedHashMap<String, String>,
snowflakeColumnUtils = snowflakeColumnUtils
)
val formattedValue = formatter.format(record)
val expectedValue = val expectedValue =
createExpected( createExpected(
record = record, record = record,
columns = columns, columnSchema = columnSchema,
) )
assertEquals(expectedValue, formattedValue) assertEquals(expectedValue, formattedValue)
} }
@@ -72,21 +64,15 @@ internal class SnowflakeSchemaRecordFormatterTest {
fun testFormattingVariant() { fun testFormattingVariant() {
val columnName = "test-column-name" val columnName = "test-column-name"
val columnValue = "{\"test\": \"test-value\"}" val columnValue = "{\"test\": \"test-value\"}"
val columns = val userColumns = mapOf(columnName to "VARIANT")
(AIRBYTE_COLUMN_TYPES_MAP + linkedMapOf(columnName to "VARIANT")).mapKeys { val columnSchema = createColumnSchema(userColumns)
it.key.toSnowflakeCompatibleName()
}
val record = createRecord(columnName, columnValue) val record = createRecord(columnName, columnValue)
val formatter = val formatter = SnowflakeSchemaRecordFormatter()
SnowflakeSchemaRecordFormatter( val formattedValue = formatter.format(record, columnSchema)
columns = columns as LinkedHashMap<String, String>,
snowflakeColumnUtils = snowflakeColumnUtils
)
val formattedValue = formatter.format(record)
val expectedValue = val expectedValue =
createExpected( createExpected(
record = record, record = record,
columns = columns, columnSchema = columnSchema,
) )
assertEquals(expectedValue, formattedValue) assertEquals(expectedValue, formattedValue)
} }
@@ -95,23 +81,16 @@ internal class SnowflakeSchemaRecordFormatterTest {
fun testFormattingMissingColumn() { fun testFormattingMissingColumn() {
val columnName = "test-column-name" val columnName = "test-column-name"
val columnValue = "test-column-value" val columnValue = "test-column-value"
val columns = val userColumns =
AIRBYTE_COLUMN_TYPES_MAP + mapOf(columnName to "VARCHAR(16777216)", "missing-column" to "VARCHAR(16777216)")
linkedMapOf( val columnSchema = createColumnSchema(userColumns)
columnName to "VARCHAR(16777216)",
"missing-column" to "VARCHAR(16777216)"
)
val record = createRecord(columnName, columnValue) val record = createRecord(columnName, columnValue)
val formatter = val formatter = SnowflakeSchemaRecordFormatter()
SnowflakeSchemaRecordFormatter( val formattedValue = formatter.format(record, columnSchema)
columns = columns as LinkedHashMap<String, String>,
snowflakeColumnUtils = snowflakeColumnUtils
)
val formattedValue = formatter.format(record)
val expectedValue = val expectedValue =
createExpected( createExpected(
record = record, record = record,
columns = columns, columnSchema = columnSchema,
filterMissing = false, filterMissing = false,
) )
assertEquals(expectedValue, formattedValue) assertEquals(expectedValue, formattedValue)
@@ -128,16 +107,37 @@ internal class SnowflakeSchemaRecordFormatterTest {
private fun createExpected( private fun createExpected(
record: Map<String, AirbyteValue>, record: Map<String, AirbyteValue>,
columns: Map<String, String>, columnSchema: ColumnSchema,
filterMissing: Boolean = true, filterMissing: Boolean = true,
) = ): List<Any> {
record.entries val columns = columnSchema.finalSchema.keys.toList()
.associate { entry -> entry.key.toSnowflakeCompatibleName() to entry.value } val result = mutableListOf<Any>()
.map { entry -> AbstractMap.SimpleEntry(entry.key, entry.value.toCsvValue()) }
.sortedBy { entry -> // Add meta columns first in the expected order
if (columns.keys.indexOf(entry.key) > -1) columns.keys.indexOf(entry.key) result.add(record[COLUMN_NAME_AB_RAW_ID]?.toCsvValue() ?: "")
else Int.MAX_VALUE result.add(record[COLUMN_NAME_AB_EXTRACTED_AT]?.toCsvValue() ?: "")
result.add(record[COLUMN_NAME_AB_META]?.toCsvValue() ?: "")
result.add(record[COLUMN_NAME_AB_GENERATION_ID]?.toCsvValue() ?: "")
// Add user columns
val userColumns =
columns.filterNot { col ->
listOf(
COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName(),
COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName(),
COLUMN_NAME_AB_META.toSnowflakeCompatibleName(),
COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()
)
.contains(col)
} }
.filter { (k, _) -> if (filterMissing) columns.contains(k) else true }
.map { it.value } userColumns.forEach { columnName ->
val value = record[columnName] ?: if (!filterMissing) NullValue else null
if (value != null || !filterMissing) {
result.add(value?.toCsvValue() ?: "")
}
}
return result
}
} }