Merge branch 'master' into devin/1765393705-fix-youtube-analytics-job-creation
@@ -139,8 +139,8 @@ runs:
           CONNECTOR_VERSION_TAG="${{ inputs.tag-override }}"
           echo "🏷 Using provided tag override: $CONNECTOR_VERSION_TAG"
         elif [[ "${{ inputs.release-type }}" == "pre-release" ]]; then
-          hash=$(git rev-parse --short=10 HEAD)
-          CONNECTOR_VERSION_TAG="${CONNECTOR_VERSION}-dev.${hash}"
+          hash=$(git rev-parse --short=7 HEAD)
+          CONNECTOR_VERSION_TAG="${CONNECTOR_VERSION}-preview.${hash}"
           echo "🏷 Using pre-release tag: $CONNECTOR_VERSION_TAG"
         else
           CONNECTOR_VERSION_TAG="$CONNECTOR_VERSION"
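For local debugging, the tag derivation above can be reproduced outside of CI. The following is a minimal sketch; `RELEASE_TYPE` and the hard-coded version stand in for the action's actual inputs (those names are illustrative, not part of the action).

```bash
#!/usr/bin/env bash
# Minimal sketch of the pre-release tag derivation shown above.
set -euo pipefail

CONNECTOR_VERSION="1.2.3"      # normally read from the connector's metadata
RELEASE_TYPE="pre-release"     # stand-in for inputs.release-type

if [[ "$RELEASE_TYPE" == "pre-release" ]]; then
  hash=$(git rev-parse --short=7 HEAD)
  CONNECTOR_VERSION_TAG="${CONNECTOR_VERSION}-preview.${hash}"
else
  CONNECTOR_VERSION_TAG="$CONNECTOR_VERSION"
fi

echo "$CONNECTOR_VERSION_TAG"  # e.g. 1.2.3-preview.ab12cd3
```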
.github/pr-welcome-community.md (2 changed lines, vendored)

@@ -21,7 +21,7 @@ As needed or by request, Airbyte Maintainers can execute the following slash commands
 - `/run-live-tests` - Runs live tests for the modified connector(s).
 - `/run-regression-tests` - Runs regression tests for the modified connector(s).
 - `/build-connector-images` - Builds and publishes a pre-release docker image for the modified connector(s).
-- `/publish-connectors-prerelease` - Publishes pre-release connector builds (tagged as `{version}-dev.{git-sha}`) for all modified connectors in the PR.
+- `/publish-connectors-prerelease` - Publishes pre-release connector builds (tagged as `{version}-preview.{git-sha}`) for all modified connectors in the PR.

 If you have any questions, feel free to ask in the PR comments or join our [Slack community](https://airbytehq.slack.com/).

.github/pr-welcome-internal.md (6 changed lines, vendored)

@@ -28,7 +28,11 @@ Airbyte Maintainers (that's you!) can execute the following slash commands on your PR:
 - `/run-live-tests` - Runs live tests for the modified connector(s).
 - `/run-regression-tests` - Runs regression tests for the modified connector(s).
 - `/build-connector-images` - Builds and publishes a pre-release docker image for the modified connector(s).
-- `/publish-connectors-prerelease` - Publishes pre-release connector builds (tagged as `{version}-dev.{git-sha}`) for all modified connectors in the PR.
+- `/publish-connectors-prerelease` - Publishes pre-release connector builds (tagged as `{version}-preview.{git-sha}`) for all modified connectors in the PR.
+- Connector release lifecycle (AI-powered):
+  - `/ai-prove-fix` - Runs prerelease readiness checks, including testing against customer connections.
+  - `/ai-canary-prerelease` - Rolls out prerelease to 5-10 connections for canary testing.
+  - `/ai-release-watch` - Monitors rollout post-release and tracks sync success rates.
 - JVM connectors:
   - `/update-connector-cdk-version connector=<CONNECTOR_NAME>` - Updates the specified connector to the latest CDK version.
     Example: `/update-connector-cdk-version connector=destination-bigquery`
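Taken together, the three AI commands describe a release lifecycle. An illustrative sequence of PR comments (one comment per step, in the order the descriptions above imply) would be:

```
/ai-prove-fix
/ai-canary-prerelease
/ai-release-watch
```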
.github/workflows/ai-canary-prerelease-command.yml (new file, 72 lines, vendored)

@@ -0,0 +1,72 @@
name: AI Canary Prerelease Command

on:
  workflow_dispatch:
    inputs:
      pr:
        description: "Pull request number (if triggered from a PR)"
        type: number
        required: false
      comment-id:
        description: "The comment-id of the slash command. Used to update the comment with the status."
        required: false
      repo:
        description: "Repo (passed by slash command dispatcher)"
        required: false
        default: "airbytehq/airbyte"
      gitref:
        description: "Git ref (passed by slash command dispatcher)"
        required: false

run-name: "AI Canary Prerelease for PR #${{ github.event.inputs.pr }}"

permissions:
  contents: read
  issues: write
  pull-requests: read

jobs:
  ai-canary-prerelease:
    runs-on: ubuntu-latest
    steps:
      - name: Get job variables
        id: job-vars
        run: |
          echo "run-url=https://github.com/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID" >> $GITHUB_OUTPUT

      - name: Checkout code
        uses: actions/checkout@v4

      - name: Authenticate as GitHub App
        uses: actions/create-github-app-token@v2
        id: get-app-token
        with:
          owner: "airbytehq"
          repositories: "airbyte,oncall"
          app-id: ${{ secrets.OCTAVIA_BOT_APP_ID }}
          private-key: ${{ secrets.OCTAVIA_BOT_PRIVATE_KEY }}

      - name: Post start comment
        if: inputs.comment-id != ''
        uses: peter-evans/create-or-update-comment@v4
        with:
          token: ${{ steps.get-app-token.outputs.token }}
          comment-id: ${{ inputs.comment-id }}
          issue-number: ${{ inputs.pr }}
          body: |
            > **AI Canary Prerelease Started**
            >
            > Rolling out to 5-10 connections, watching results, and reporting findings.
            > [View workflow run](${{ steps.job-vars.outputs.run-url }})

      - name: Run AI Canary Prerelease
        uses: aaronsteers/devin-action@main
        with:
          comment-id: ${{ inputs.comment-id }}
          issue-number: ${{ inputs.pr }}
          playbook-macro: "!canary_prerelease"
          devin-token: ${{ secrets.DEVIN_AI_API_KEY }}
          github-token: ${{ steps.get-app-token.outputs.token }}
          start-message: "🐤 **AI Canary Prerelease session starting...** Rolling out to 5-10 connections, watching results, and reporting findings. [View playbook](https://github.com/airbytehq/oncall/blob/main/prompts/playbooks/canary_prerelease.md)"
          tags: |
            ai-oncall
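The workflow above is normally started by the slash-command dispatcher, but since it is a plain `workflow_dispatch` workflow it can also be triggered by hand. A rough example with the GitHub CLI; the PR number, comment id, and branch name below are placeholders:

```bash
# Manually dispatch the canary prerelease workflow; input values are illustrative.
gh workflow run ai-canary-prerelease-command.yml \
  --repo airbytehq/airbyte \
  -f pr=12345 \
  -f comment-id=9876543210 \
  -f gitref=my-connector-fix-branch
```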
.github/workflows/ai-prove-fix-command.yml (new file, 72 lines, vendored)

@@ -0,0 +1,72 @@
name: AI Prove Fix Command

on:
  workflow_dispatch:
    inputs:
      pr:
        description: "Pull request number (if triggered from a PR)"
        type: number
        required: false
      comment-id:
        description: "The comment-id of the slash command. Used to update the comment with the status."
        required: false
      repo:
        description: "Repo (passed by slash command dispatcher)"
        required: false
        default: "airbytehq/airbyte"
      gitref:
        description: "Git ref (passed by slash command dispatcher)"
        required: false

run-name: "AI Prove Fix for PR #${{ github.event.inputs.pr }}"

permissions:
  contents: read
  issues: write
  pull-requests: read

jobs:
  ai-prove-fix:
    runs-on: ubuntu-latest
    steps:
      - name: Get job variables
        id: job-vars
        run: |
          echo "run-url=https://github.com/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID" >> $GITHUB_OUTPUT

      - name: Checkout code
        uses: actions/checkout@v4

      - name: Authenticate as GitHub App
        uses: actions/create-github-app-token@v2
        id: get-app-token
        with:
          owner: "airbytehq"
          repositories: "airbyte,oncall"
          app-id: ${{ secrets.OCTAVIA_BOT_APP_ID }}
          private-key: ${{ secrets.OCTAVIA_BOT_PRIVATE_KEY }}

      - name: Post start comment
        if: inputs.comment-id != ''
        uses: peter-evans/create-or-update-comment@v4
        with:
          token: ${{ steps.get-app-token.outputs.token }}
          comment-id: ${{ inputs.comment-id }}
          issue-number: ${{ inputs.pr }}
          body: |
            > **AI Prove Fix Started**
            >
            > Running readiness checks and testing against customer connections.
            > [View workflow run](${{ steps.job-vars.outputs.run-url }})

      - name: Run AI Prove Fix
        uses: aaronsteers/devin-action@main
        with:
          comment-id: ${{ inputs.comment-id }}
          issue-number: ${{ inputs.pr }}
          playbook-macro: "!prove_fix"
          devin-token: ${{ secrets.DEVIN_AI_API_KEY }}
          github-token: ${{ steps.get-app-token.outputs.token }}
          start-message: "🔍 **AI Prove Fix session starting...** Running readiness checks and testing against customer connections. [View playbook](https://github.com/airbytehq/oncall/blob/main/prompts/playbooks/prove_fix.md)"
          tags: |
            ai-oncall
.github/workflows/ai-release-watch-command.yml (new file, 72 lines, vendored)

@@ -0,0 +1,72 @@
name: AI Release Watch Command

on:
  workflow_dispatch:
    inputs:
      pr:
        description: "Pull request number (if triggered from a PR)"
        type: number
        required: false
      comment-id:
        description: "The comment-id of the slash command. Used to update the comment with the status."
        required: false
      repo:
        description: "Repo (passed by slash command dispatcher)"
        required: false
        default: "airbytehq/airbyte"
      gitref:
        description: "Git ref (passed by slash command dispatcher)"
        required: false

run-name: "AI Release Watch for PR #${{ github.event.inputs.pr }}"

permissions:
  contents: read
  issues: write
  pull-requests: read

jobs:
  ai-release-watch:
    runs-on: ubuntu-latest
    steps:
      - name: Get job variables
        id: job-vars
        run: |
          echo "run-url=https://github.com/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID" >> $GITHUB_OUTPUT

      - name: Checkout code
        uses: actions/checkout@v4

      - name: Authenticate as GitHub App
        uses: actions/create-github-app-token@v2
        id: get-app-token
        with:
          owner: "airbytehq"
          repositories: "airbyte,oncall"
          app-id: ${{ secrets.OCTAVIA_BOT_APP_ID }}
          private-key: ${{ secrets.OCTAVIA_BOT_PRIVATE_KEY }}

      - name: Post start comment
        if: inputs.comment-id != ''
        uses: peter-evans/create-or-update-comment@v4
        with:
          token: ${{ steps.get-app-token.outputs.token }}
          comment-id: ${{ inputs.comment-id }}
          issue-number: ${{ inputs.pr }}
          body: |
            > **AI Release Watch Started**
            >
            > Monitoring rollout and tracking sync success rates.
            > [View workflow run](${{ steps.job-vars.outputs.run-url }})

      - name: Run AI Release Watch
        uses: aaronsteers/devin-action@main
        with:
          comment-id: ${{ inputs.comment-id }}
          issue-number: ${{ inputs.pr }}
          playbook-macro: "!release_watch"
          devin-token: ${{ secrets.DEVIN_AI_API_KEY }}
          github-token: ${{ steps.get-app-token.outputs.token }}
          start-message: "👁️ **AI Release Watch session starting...** Monitoring rollout and tracking sync success rates. [View playbook](https://github.com/airbytehq/oncall/blob/main/prompts/playbooks/release_watch.md)"
          tags: |
            ai-oncall
.github/workflows/autodoc.yml (2 changed lines, vendored)

@@ -104,7 +104,7 @@ jobs:
         if: steps.check-support-level.outputs.metadata_file == 'true' && steps.check-support-level.outputs.community_support == 'true'
         env:
           PROMPT_TEXT: "The commit to review is ${{ github.sha }}. This commit was pushed to master and may contain connector changes that need documentation updates."
-        uses: aaronsteers/devin-action@0d74d6d9ff1b16ada5966dc31af53a9d155759f4 # Pinned to specific commit for security
+        uses: aaronsteers/devin-action@98d15ae93d1848914f5ab8e9ce45341182958d27 # v0.1.7 - Pinned to specific commit for security
         with:
           devin-token: ${{ secrets.DEVIN_AI_API_KEY }}
           github-token: ${{ secrets.GITHUB_TOKEN }}
.github/workflows/label-community-prs.yml (new file, 28 lines, vendored)

@@ -0,0 +1,28 @@
name: Label Community PRs

# This workflow automatically adds the "community" label to PRs from forks.
# This enables automatic tracking on the Community PRs project board.

on:
  pull_request_target:
    types:
      - opened
      - reopened

jobs:
  label-community-pr:
    name: Add "Community" Label to PR
    # Only run for PRs from forks
    if: github.event.pull_request.head.repo.fork == true
    runs-on: ubuntu-24.04
    permissions:
      issues: write
      pull-requests: write
    steps:
      - name: Add community label
        # This action uses GitHub's addLabels API, which is idempotent.
        # If the label already exists, the API call succeeds without error.
        uses: actions-ecosystem/action-add-labels@bd52874380e3909a1ac983768df6976535ece7f8 # v1.1.3
        with:
          github_token: ${{ secrets.GITHUB_TOKEN }}
          labels: community
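The comment in the step above notes that GitHub's addLabels API is idempotent. As a rough illustration (not what the workflow itself runs), the same call can be made directly with the GitHub CLI; the PR number here is a placeholder:

```bash
# Add the "community" label to a PR. Re-running this when the label is already
# present succeeds without error, which is what makes the workflow step idempotent.
gh api repos/airbytehq/airbyte/issues/12345/labels -f "labels[]=community"
```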
@@ -3,7 +3,7 @@ name: Publish Connectors Pre-release
 # It can be triggered via the /publish-connectors-prerelease slash command from PR comments,
 # or via the MCP tool `publish_connector_to_airbyte_registry`.
 #
-# Pre-release versions are tagged with the format: {version}-dev.{10-char-git-sha}
+# Pre-release versions are tagged with the format: {version}-preview.{7-char-git-sha}
 # These versions are NOT eligible for semver auto-advancement but ARE available
 # for version pinning via the scoped_configuration API.
 #

@@ -66,7 +66,7 @@ jobs:
       - name: Get short SHA
         id: get-sha
         run: |
-          SHORT_SHA=$(git rev-parse --short=10 HEAD)
+          SHORT_SHA=$(git rev-parse --short=7 HEAD)
           echo "short-sha=$SHORT_SHA" >> $GITHUB_OUTPUT

       - name: Get job variables

@@ -135,7 +135,7 @@ jobs:
             > Publishing pre-release build for connector `${{ steps.resolve-connector.outputs.connector-name }}`.
             > Branch: `${{ inputs.gitref }}`
             >
-            > Pre-release versions will be tagged as `{version}-dev.${{ steps.get-sha.outputs.short-sha }}`
+            > Pre-release versions will be tagged as `{version}-preview.${{ steps.get-sha.outputs.short-sha }}`
             > and are available for version pinning via the scoped_configuration API.
             >
             > [View workflow run](${{ steps.job-vars.outputs.run-url }})

@@ -147,6 +147,7 @@ jobs:
     with:
       connectors: ${{ format('--name={0}', needs.init.outputs.connector-name) }}
       release-type: pre-release
+      gitref: ${{ inputs.gitref }}
     secrets: inherit

   post-completion:

@@ -176,13 +177,12 @@ jobs:
         id: message-vars
         run: |
           CONNECTOR_NAME="${{ needs.init.outputs.connector-name }}"
-          SHORT_SHA="${{ needs.init.outputs.short-sha }}"
-          VERSION="${{ needs.init.outputs.connector-version }}"
+          # Use the actual docker-image-tag from the publish workflow output
+          DOCKER_TAG="${{ needs.publish.outputs.docker-image-tag }}"

-          if [[ -n "$VERSION" ]]; then
-            DOCKER_TAG="${VERSION}-dev.${SHORT_SHA}"
-          else
-            DOCKER_TAG="{version}-dev.${SHORT_SHA}"
+          if [[ -z "$DOCKER_TAG" ]]; then
+            echo "::error::docker-image-tag output is missing from publish workflow. This is unexpected."
+            exit 1
           fi

           echo "connector_name=$CONNECTOR_NAME" >> $GITHUB_OUTPUT
.github/workflows/publish_connectors.yml (19 changed lines, vendored)

@@ -21,6 +21,14 @@ on:
         required: false
         default: false
         type: boolean
+      gitref:
+        description: "Git ref (branch or SHA) to build connectors from. Used by pre-release workflow to build from PR branches."
+        required: false
+        type: string
+    outputs:
+      docker-image-tag:
+        description: "Docker image tag used when publishing. For single-connector callers only; multi-connector callers should not rely on this output."
+        value: ${{ jobs.publish_connector_registry_entries.outputs.docker-image-tag }}
   workflow_dispatch:
     inputs:
       connectors:

@@ -48,6 +56,7 @@ jobs:
         # v4
         uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955
         with:
+          ref: ${{ inputs.gitref || '' }}
           fetch-depth: 2 # Required so we can conduct a diff from the previous commit to understand what connectors have changed.
           submodules: true # Required for the enterprise repo since it uses a submodule that needs to exist for this workflow to run successfully.
       - name: List connectors to publish [manual]

@@ -105,6 +114,7 @@ jobs:
         # v4
         uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955
         with:
+          ref: ${{ inputs.gitref || '' }}
           fetch-depth: 2 # Required so we can conduct a diff from the previous commit to understand what connectors have changed.
           submodules: true # Required for the enterprise repo since it uses a submodule that needs to exist for this workflow to run successfully.

@@ -250,11 +260,14 @@ jobs:
       max-parallel: 5
       # Allow all jobs to run, even if one fails
       fail-fast: false
+    outputs:
+      docker-image-tag: ${{ steps.connector-metadata.outputs.docker-image-tag }}
     steps:
       - name: Checkout Airbyte
         # v4
         uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955
         with:
+          ref: ${{ inputs.gitref || '' }}
           fetch-depth: 2 # Required so we can conduct a diff from the previous commit to understand what connectors have changed.
           submodules: true # Required for the enterprise repo since it uses a submodule that needs to exist for this workflow to run successfully.

@@ -292,8 +305,8 @@ jobs:
           echo "connector-version=$(poe -qq get-version)" | tee -a $GITHUB_OUTPUT
           CONNECTOR_VERSION=$(poe -qq get-version)
           if [[ "${{ inputs.release-type }}" == "pre-release" ]]; then
-            hash=$(git rev-parse --short=10 HEAD)
-            echo "docker-image-tag=${CONNECTOR_VERSION}-dev.${hash}" | tee -a $GITHUB_OUTPUT
+            hash=$(git rev-parse --short=7 HEAD)
+            echo "docker-image-tag=${CONNECTOR_VERSION}-preview.${hash}" | tee -a $GITHUB_OUTPUT
             echo "release-type-flag=--pre-release" | tee -a $GITHUB_OUTPUT
           else
             echo "docker-image-tag=${CONNECTOR_VERSION}" | tee -a $GITHUB_OUTPUT

@@ -349,6 +362,7 @@ jobs:
         # v4
         uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955
         with:
+          ref: ${{ inputs.gitref || '' }}
           submodules: true # Required for the enterprise repo since it uses a submodule that needs to exist for this workflow to run successfully.
       - name: Match GitHub User to Slack User
         id: match-github-to-slack-user

@@ -381,6 +395,7 @@ jobs:
         # v4
         uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955
         with:
+          ref: ${{ inputs.gitref || '' }}
           submodules: true # Required for the enterprise repo since it uses a submodule that needs to exist for this workflow to run successfully.
       - name: Notify PagerDuty
         id: pager-duty
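The `outputs.docker-image-tag` block added above is what lets the pre-release workflow read the published tag back via `needs.publish.outputs.docker-image-tag`. The following is a minimal sketch of that producer side of the wiring, with hypothetical workflow, job, and step names; it is not the actual publish workflow:

```yaml
# producer.yml — a reusable workflow exposing a job output (names are illustrative)
on:
  workflow_call:
    outputs:
      docker-image-tag:
        description: "Tag produced by the build job"
        value: ${{ jobs.build.outputs.docker-image-tag }}

jobs:
  build:
    runs-on: ubuntu-latest
    outputs:
      docker-image-tag: ${{ steps.meta.outputs.docker-image-tag }}
    steps:
      - id: meta
        run: echo "docker-image-tag=1.2.3-preview.ab12cd3" >> "$GITHUB_OUTPUT"
```

On the caller's side the value is then available as `needs.<job>.outputs.docker-image-tag`, which is exactly what the pre-release workflow's `post-completion` job reads in the hunk above.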
.github/workflows/slash-commands.yml (3 changed lines, vendored)

@@ -35,6 +35,9 @@ jobs:
           issue-type: both

           commands: |
+            ai-canary-prerelease
+            ai-prove-fix
+            ai-release-watch
             approve-regression-tests
             bump-bulk-cdk-version
             bump-progressive-rollout-version
.github/workflows/sync-ai-connector-docs.yml (new file, 70 lines, vendored)

@@ -0,0 +1,70 @@
name: Sync Agent Connector Docs

on:
  schedule:
    - cron: "0 */2 * * *" # Every 2 hours
  workflow_dispatch: # Manual trigger

jobs:
  sync-docs:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout airbyte repo
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0

      - name: Checkout airbyte-agent-connectors
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
        with:
          repository: airbytehq/airbyte-agent-connectors
          path: agent-connectors-source

      - name: Sync connector docs
        run: |
          DEST_DIR="docs/ai-agents/connectors"
          mkdir -p "$DEST_DIR"

          for connector_dir in agent-connectors-source/connectors/*/; do
            connector=$(basename "$connector_dir")

            # Only delete/recreate the specific connector subdirectory
            # This leaves any files directly in $DEST_DIR untouched
            rm -rf "$DEST_DIR/$connector"
            mkdir -p "$DEST_DIR/$connector"

            # Copy all markdown files for this connector
            for md_file in "$connector_dir"/*.md; do
              if [ -f "$md_file" ]; then
                cp "$md_file" "$DEST_DIR/$connector/"
              fi
            done
          done

          echo "Synced $(ls -d $DEST_DIR/*/ 2>/dev/null | wc -l) connectors"

      - name: Cleanup temporary checkout
        run: rm -rf agent-connectors-source

      - name: Authenticate as GitHub App
        uses: actions/create-github-app-token@v2
        id: get-app-token
        with:
          owner: "airbytehq"
          repositories: "airbyte"
          app-id: ${{ secrets.OCTAVIA_BOT_APP_ID }}
          private-key: ${{ secrets.OCTAVIA_BOT_PRIVATE_KEY }}

      - name: Create PR if changes
        uses: peter-evans/create-pull-request@0979079bc20c05bbbb590a56c21c4e2b1d1f1bbe # v6
        with:
          token: ${{ steps.get-app-token.outputs.token }}
          commit-message: "docs: sync agent connector docs from airbyte-agent-connectors repo"
          branch: auto-sync-ai-connector-docs
          delete-branch: true
          title: "docs: sync agent connector docs from airbyte-agent-connectors repo"
          body: |
            Automated sync of agent connector docs from airbyte-agent-connectors.

            This PR was automatically created by the sync-agent-connector-docs workflow.
          labels: |
            documentation
            auto-merge
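The sync step is plain bash, so it can be exercised locally before relying on the scheduled run. A rough sketch, assuming you are at the root of an airbyte checkout (the local clone path is illustrative):

```bash
# Clone the source repo into the same directory name the workflow uses,
# then run the same copy loop locally.
git clone https://github.com/airbytehq/airbyte-agent-connectors agent-connectors-source

DEST_DIR="docs/ai-agents/connectors"
mkdir -p "$DEST_DIR"
for connector_dir in agent-connectors-source/connectors/*/; do
  connector=$(basename "$connector_dir")
  rm -rf "$DEST_DIR/$connector" && mkdir -p "$DEST_DIR/$connector"
  cp "$connector_dir"/*.md "$DEST_DIR/$connector/" 2>/dev/null || true
done
echo "Synced $(ls -d "$DEST_DIR"/*/ 2>/dev/null | wc -l) connectors"

rm -rf agent-connectors-source  # mirror the workflow's cleanup step
```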
.markdownlintignore (new file, 3 lines)

@@ -0,0 +1,3 @@
# Ignore auto-generated connector documentation files synced from airbyte-agent-connectors repo
# These files are generated and have formatting that doesn't conform to markdownlint rules
docs/ai-agents/connectors/**
@@ -1,3 +1,21 @@
+## Version 0.1.91
+
+load cdk: upsert records test uses proper target schema
+
+## Version 0.1.90
+
+load cdk: components tests: data coercion tests cover all data types
+
+## Version 0.1.89
+
+load cdk: components tests: data coercion tests for int+number
+
+## Version 0.1.88
+
+**Load CDK**
+
+* Add CDC_CURSOR_COLUMN_NAME constant.
+
 ## Version 0.1.87

 **Load CDK**
@@ -4,4 +4,13 @@

 package io.airbyte.cdk.load.table

+/**
+ * CDC meta column names.
+ *
+ * Note: These CDC column names are brittle as they are separate yet coupled to the logic sources
+ * use to generate these column names. See
+ * [io.airbyte.integrations.source.mssql.MsSqlSourceOperations.MsSqlServerCdcMetaFields] for an
+ * example.
+ */
 const val CDC_DELETED_AT_COLUMN = "_ab_cdc_deleted_at"
+const val CDC_CURSOR_COLUMN = "_ab_cdc_cursor"
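The two constants above give load destinations a single place to reference the CDC meta columns. As a hedged sketch of how a destination might consult `CDC_DELETED_AT_COLUMN` when deciding whether an incoming record represents a CDC deletion (the plain-map record shape here is purely illustrative):

```kotlin
import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN

/**
 * Illustrative helper: treat a record as a CDC delete when the
 * `_ab_cdc_deleted_at` meta column is present and non-null.
 */
fun isCdcDelete(record: Map<String, Any?>): Boolean =
    record[CDC_DELETED_AT_COLUMN] != null
```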
@@ -0,0 +1,859 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.cdk.load.component

import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.ArrayValue
import io.airbyte.cdk.load.data.DateValue
import io.airbyte.cdk.load.data.IntegerValue
import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.data.NumberValue
import io.airbyte.cdk.load.data.ObjectValue
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.data.TimeWithTimezoneValue
import io.airbyte.cdk.load.data.TimeWithoutTimezoneValue
import io.airbyte.cdk.load.data.TimestampWithTimezoneValue
import io.airbyte.cdk.load.data.TimestampWithoutTimezoneValue
import io.airbyte.cdk.load.dataflow.transform.ValueCoercer
import io.airbyte.cdk.load.util.serializeToString
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Reason
import java.math.BigDecimal
import java.math.BigInteger
import java.time.LocalDate
import java.time.LocalDateTime
import java.time.OffsetDateTime
import java.time.format.DateTimeFormatter
import java.time.format.DateTimeFormatterBuilder
import java.time.format.SignStyle
import java.time.temporal.ChronoField
import org.junit.jupiter.params.provider.Arguments

/*
 * This file defines "interesting values" for all data types, along with expected behavior for those values.
 * You're free to define your own values/behavior depending on the destination, but it's recommended
 * that you try to match behavior to an existing fixture.
 *
 * Classes also include some convenience functions for JUnit. For example, you could annotate your
 * method with:
 * ```kotlin
 * @ParameterizedTest
 * @MethodSource("io.airbyte.cdk.load.component.DataCoercionIntegerFixtures#int64")
 * ```
 *
 * By convention, all fixtures are declared as:
 * 1. One or more `val <name>: List<Pair<AirbyteValue, Any?>>` (each pair representing the input value,
 *    and the expected output value)
 * 2. One or more `fun <name>(): List<Arguments> = <name>.toArgs()`, which can be provided to JUnit's MethodSource
 *
 * If you need to mutate fixtures in some way, you should reference the `val`, and use the `toArgs()`
 * extension function to convert it to JUnit's Arguments class. See [DataCoercionIntegerFixtures.int64AsBigInteger]
 * for an example.
 */

|
||||
// "9".repeat(38)
|
||||
val numeric38_0Max = bigint("99999999999999999999999999999999999999")
|
||||
val numeric38_0Min = bigint("-99999999999999999999999999999999999999")
|
||||
|
||||
const val ZERO = "0"
|
||||
const val ONE = "1"
|
||||
const val NEGATIVE_ONE = "-1"
|
||||
const val FORTY_TWO = "42"
|
||||
const val NEGATIVE_FORTY_TWO = "-42"
|
||||
const val INT32_MAX = "int32 max"
|
||||
const val INT32_MIN = "int32 min"
|
||||
const val INT32_MAX_PLUS_ONE = "int32_max + 1"
|
||||
const val INT32_MIN_MINUS_ONE = "int32_min - 1"
|
||||
const val INT64_MAX = "int64 max"
|
||||
const val INT64_MIN = "int64 min"
|
||||
const val INT64_MAX_PLUS_ONE = "int64_max + 1"
|
||||
const val INT64_MIN_MINUS_1 = "int64_min - 1"
|
||||
const val NUMERIC_38_0_MAX = "numeric(38,0) max"
|
||||
const val NUMERIC_38_0_MIN = "numeric(38,0) min"
|
||||
const val NUMERIC_38_0_MAX_PLUS_ONE = "numeric(38,0)_max + 1"
|
||||
const val NUMERIC_38_0_MIN_MINUS_ONE = "numeric(38,0)_min - 1"
|
||||
|
||||
/**
|
||||
* Many destinations use int64 to represent integers. In this case, we null out any value beyond
|
||||
* Long.MIN/MAX_VALUE.
|
||||
*/
|
||||
val int64 =
|
||||
listOf(
|
||||
case(NULL, NullValue, null),
|
||||
case(ZERO, IntegerValue(0), 0L),
|
||||
case(ONE, IntegerValue(1), 1L),
|
||||
case(NEGATIVE_ONE, IntegerValue(-1), -1L),
|
||||
case(FORTY_TWO, IntegerValue(42), 42L),
|
||||
case(NEGATIVE_FORTY_TWO, IntegerValue(-42), -42L),
|
||||
// int32 bounds, and slightly out of bounds
|
||||
case(INT32_MAX, IntegerValue(Integer.MAX_VALUE.toLong()), Integer.MAX_VALUE.toLong()),
|
||||
case(INT32_MIN, IntegerValue(Integer.MIN_VALUE.toLong()), Integer.MIN_VALUE.toLong()),
|
||||
case(
|
||||
INT32_MAX_PLUS_ONE,
|
||||
IntegerValue(Integer.MAX_VALUE.toLong() + 1),
|
||||
Integer.MAX_VALUE.toLong() + 1
|
||||
),
|
||||
case(
|
||||
INT32_MIN_MINUS_ONE,
|
||||
IntegerValue(Integer.MIN_VALUE.toLong() - 1),
|
||||
Integer.MIN_VALUE.toLong() - 1
|
||||
),
|
||||
// int64 bounds, and slightly out of bounds
|
||||
case(INT64_MAX, IntegerValue(Long.MAX_VALUE), Long.MAX_VALUE),
|
||||
case(INT64_MIN, IntegerValue(Long.MIN_VALUE), Long.MIN_VALUE),
|
||||
// values out of int64 bounds are nulled
|
||||
case(
|
||||
INT64_MAX_PLUS_ONE,
|
||||
IntegerValue(bigint(Long.MAX_VALUE) + BigInteger.ONE),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
INT64_MIN_MINUS_1,
|
||||
IntegerValue(bigint(Long.MIN_VALUE) - BigInteger.ONE),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
// NUMERIC(38, 9) bounds, and slightly out of bounds
|
||||
// (these are all out of bounds for an int64 value, so they all get nulled)
|
||||
case(
|
||||
NUMERIC_38_0_MAX,
|
||||
IntegerValue(numeric38_0Max),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
NUMERIC_38_0_MIN,
|
||||
IntegerValue(numeric38_0Min),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
NUMERIC_38_0_MAX_PLUS_ONE,
|
||||
IntegerValue(numeric38_0Max + BigInteger.ONE),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
NUMERIC_38_0_MIN_MINUS_ONE,
|
||||
IntegerValue(numeric38_0Min - BigInteger.ONE),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
)
|
||||
|
||||
/**
|
||||
* Many destination warehouses represent integers as a fixed-point type with 38 digits of
|
||||
* precision. In this case, we only need to null out numbers larger than `1e38 - 1` / smaller
|
||||
* than `-1e38 + 1`.
|
||||
*/
|
||||
val numeric38_0 =
|
||||
listOf(
|
||||
case(NULL, NullValue, null),
|
||||
case(ZERO, IntegerValue(0), bigint(0L)),
|
||||
case(ONE, IntegerValue(1), bigint(1L)),
|
||||
case(NEGATIVE_ONE, IntegerValue(-1), bigint(-1L)),
|
||||
case(FORTY_TWO, IntegerValue(42), bigint(42L)),
|
||||
case(NEGATIVE_FORTY_TWO, IntegerValue(-42), bigint(-42L)),
|
||||
// int32 bounds, and slightly out of bounds
|
||||
case(
|
||||
INT32_MAX,
|
||||
IntegerValue(Integer.MAX_VALUE.toLong()),
|
||||
bigint(Integer.MAX_VALUE.toLong())
|
||||
),
|
||||
case(
|
||||
INT32_MIN,
|
||||
IntegerValue(Integer.MIN_VALUE.toLong()),
|
||||
bigint(Integer.MIN_VALUE.toLong())
|
||||
),
|
||||
case(
|
||||
INT32_MAX_PLUS_ONE,
|
||||
IntegerValue(Integer.MAX_VALUE.toLong() + 1),
|
||||
bigint(Integer.MAX_VALUE.toLong() + 1)
|
||||
),
|
||||
case(
|
||||
INT32_MIN_MINUS_ONE,
|
||||
IntegerValue(Integer.MIN_VALUE.toLong() - 1),
|
||||
bigint(Integer.MIN_VALUE.toLong() - 1)
|
||||
),
|
||||
// int64 bounds, and slightly out of bounds
|
||||
case(INT64_MAX, IntegerValue(Long.MAX_VALUE), bigint(Long.MAX_VALUE)),
|
||||
case(INT64_MIN, IntegerValue(Long.MIN_VALUE), bigint(Long.MIN_VALUE)),
|
||||
case(
|
||||
INT64_MAX_PLUS_ONE,
|
||||
IntegerValue(bigint(Long.MAX_VALUE) + BigInteger.ONE),
|
||||
bigint(Long.MAX_VALUE) + BigInteger.ONE
|
||||
),
|
||||
case(
|
||||
INT64_MIN_MINUS_1,
|
||||
IntegerValue(bigint(Long.MIN_VALUE) - BigInteger.ONE),
|
||||
bigint(Long.MIN_VALUE) - BigInteger.ONE
|
||||
),
|
||||
// NUMERIC(38, 9) bounds, and slightly out of bounds
|
||||
case(NUMERIC_38_0_MAX, IntegerValue(numeric38_0Max), numeric38_0Max),
|
||||
case(NUMERIC_38_0_MIN, IntegerValue(numeric38_0Min), numeric38_0Min),
|
||||
// These values exceed the 38-digit range, so they get nulled out
|
||||
case(
|
||||
NUMERIC_38_0_MAX_PLUS_ONE,
|
||||
IntegerValue(numeric38_0Max + BigInteger.ONE),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
NUMERIC_38_0_MIN_MINUS_ONE,
|
||||
IntegerValue(numeric38_0Min - BigInteger.ONE),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
)
|
||||
|
||||
@JvmStatic fun int64() = int64.toArgs()
|
||||
|
||||
/**
|
||||
* Convenience fixture if your [TestTableOperationsClient] returns integers as [BigInteger]
|
||||
* rather than [Long].
|
||||
*/
|
||||
@JvmStatic
|
||||
fun int64AsBigInteger() =
|
||||
int64.map { it.copy(outputValue = it.outputValue?.let { bigint(it as Long) }) }
|
||||
|
||||
/**
|
||||
* Convenience fixture if your [TestTableOperationsClient] returns integers as [BigDecimal]
|
||||
* rather than [Long].
|
||||
*/
|
||||
@JvmStatic
|
||||
fun int64AsBigDecimal() =
|
||||
int64.map { it.copy(outputValue = it.outputValue?.let { BigDecimal.valueOf(it as Long) }) }
|
||||
|
||||
@JvmStatic fun numeric38_0() = numeric38_0.toArgs()
|
||||
}
|
||||
|
||||
object DataCoercionNumberFixtures {
|
||||
val numeric38_9Max = bigdec("99999999999999999999999999999.999999999")
|
||||
val numeric38_9Min = bigdec("-99999999999999999999999999999.999999999")
|
||||
|
||||
const val ZERO = "0"
|
||||
const val ONE = "1"
|
||||
const val NEGATIVE_ONE = "-1"
|
||||
const val ONE_HUNDRED_TWENTY_THREE_POINT_FOUR = "123.4"
|
||||
const val NEGATIVE_ONE_HUNDRED_TWENTY_THREE_POINT_FOUR = "123.4"
|
||||
const val POSITIVE_HIGH_PRECISION_FLOAT = "positive high-precision float"
|
||||
const val NEGATIVE_HIGH_PRECISION_FLOAT = "negative high-precision float"
|
||||
const val NUMERIC_38_9_MAX = "numeric(38,9) max"
|
||||
const val NUMERIC_38_9_MIN = "numeric(38,9) min"
|
||||
const val SMALLEST_POSITIVE_FLOAT32 = "smallest positive float32"
|
||||
const val SMALLEST_NEGATIVE_FLOAT32 = "smallest negative float32"
|
||||
const val LARGEST_POSITIVE_FLOAT32 = "largest positive float32"
|
||||
const val LARGEST_NEGATIVE_FLOAT32 = "largest negative float32"
|
||||
const val SMALLEST_POSITIVE_FLOAT64 = "smallest positive float64"
|
||||
const val SMALLEST_NEGATIVE_FLOAT64 = "smallest negative float64"
|
||||
const val LARGEST_POSITIVE_FLOAT64 = "largest positive float64"
|
||||
const val LARGEST_NEGATIVE_FLOAT64 = "largest negative float64"
|
||||
const val SLIGHTLY_ABOVE_LARGEST_POSITIVE_FLOAT64 = "slightly above largest positive float64"
|
||||
const val SLIGHTLY_BELOW_LARGEST_NEGATIVE_FLOAT64 = "slightly below largest negative float64"
|
||||
|
||||
val float64 =
|
||||
listOf(
|
||||
case(NULL, NullValue, null),
|
||||
case(ZERO, NumberValue(bigdec(0)), 0.0),
|
||||
case(ONE, NumberValue(bigdec(1)), 1.0),
|
||||
case(NEGATIVE_ONE, NumberValue(bigdec(-1)), -1.0),
|
||||
// This value isn't exactly representable as a float64
|
||||
// (the exact value is `123.400000000000005684341886080801486968994140625`)
|
||||
// but we should preserve the canonical representation
|
||||
case(ONE_HUNDRED_TWENTY_THREE_POINT_FOUR, NumberValue(bigdec("123.4")), 123.4),
|
||||
case(
|
||||
NEGATIVE_ONE_HUNDRED_TWENTY_THREE_POINT_FOUR,
|
||||
NumberValue(bigdec("-123.4")),
|
||||
-123.4
|
||||
),
|
||||
// These values have too much precision for a float64, so we round them
|
||||
case(
|
||||
POSITIVE_HIGH_PRECISION_FLOAT,
|
||||
NumberValue(bigdec("1234567890.1234567890123456789")),
|
||||
1234567890.1234567,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
NEGATIVE_HIGH_PRECISION_FLOAT,
|
||||
NumberValue(bigdec("-1234567890.1234567890123456789")),
|
||||
-1234567890.1234567,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
NUMERIC_38_9_MAX,
|
||||
NumberValue(numeric38_9Max),
|
||||
1.0E29,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
NUMERIC_38_9_MIN,
|
||||
NumberValue(numeric38_9Min),
|
||||
-1.0E29,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
// min/max_value are all positive values, so we need to manually test their negative
|
||||
// version
|
||||
case(
|
||||
SMALLEST_POSITIVE_FLOAT32,
|
||||
NumberValue(bigdec(Float.MIN_VALUE.toDouble())),
|
||||
Float.MIN_VALUE.toDouble()
|
||||
),
|
||||
case(
|
||||
SMALLEST_NEGATIVE_FLOAT32,
|
||||
NumberValue(bigdec(-Float.MIN_VALUE.toDouble())),
|
||||
-Float.MIN_VALUE.toDouble()
|
||||
),
|
||||
case(
|
||||
LARGEST_POSITIVE_FLOAT32,
|
||||
NumberValue(bigdec(Float.MAX_VALUE.toDouble())),
|
||||
Float.MAX_VALUE.toDouble()
|
||||
),
|
||||
case(
|
||||
LARGEST_NEGATIVE_FLOAT32,
|
||||
NumberValue(bigdec(-Float.MAX_VALUE.toDouble())),
|
||||
-Float.MAX_VALUE.toDouble()
|
||||
),
|
||||
case(
|
||||
SMALLEST_POSITIVE_FLOAT64,
|
||||
NumberValue(bigdec(Double.MIN_VALUE)),
|
||||
Double.MIN_VALUE
|
||||
),
|
||||
case(
|
||||
SMALLEST_NEGATIVE_FLOAT64,
|
||||
NumberValue(bigdec(-Double.MIN_VALUE)),
|
||||
-Double.MIN_VALUE
|
||||
),
|
||||
case(LARGEST_POSITIVE_FLOAT64, NumberValue(bigdec(Double.MAX_VALUE)), Double.MAX_VALUE),
|
||||
case(
|
||||
LARGEST_NEGATIVE_FLOAT64,
|
||||
NumberValue(bigdec(-Double.MAX_VALUE)),
|
||||
-Double.MAX_VALUE
|
||||
),
|
||||
// These values are out of bounds, so we null them
|
||||
case(
|
||||
SLIGHTLY_ABOVE_LARGEST_POSITIVE_FLOAT64,
|
||||
NumberValue(bigdec(Double.MAX_VALUE) + bigdec(Double.MIN_VALUE)),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
SLIGHTLY_BELOW_LARGEST_NEGATIVE_FLOAT64,
|
||||
NumberValue(bigdec(-Double.MAX_VALUE) - bigdec(Double.MIN_VALUE)),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
)
|
||||
|
||||
val numeric38_9 =
|
||||
listOf(
|
||||
case(NULL, NullValue, null),
|
||||
case(ZERO, NumberValue(bigdec(0)), bigdec(0.0)),
|
||||
case(ONE, NumberValue(bigdec(1)), bigdec(1.0)),
|
||||
case(NEGATIVE_ONE, NumberValue(bigdec(-1)), bigdec(-1.0)),
|
||||
// This value isn't exactly representable as a float64
|
||||
// (the exact value is `123.400000000000005684341886080801486968994140625`)
|
||||
// but it's perfectly fine as a numeric(38, 9)
|
||||
case(
|
||||
ONE_HUNDRED_TWENTY_THREE_POINT_FOUR,
|
||||
NumberValue(bigdec("123.4")),
|
||||
bigdec("123.4")
|
||||
),
|
||||
case(
|
||||
NEGATIVE_ONE_HUNDRED_TWENTY_THREE_POINT_FOUR,
|
||||
NumberValue(bigdec("-123.4")),
|
||||
bigdec("-123.4")
|
||||
),
|
||||
// These values have too much precision for a numeric(38, 9), so we round them
|
||||
case(
|
||||
POSITIVE_HIGH_PRECISION_FLOAT,
|
||||
NumberValue(bigdec("1234567890.1234567890123456789")),
|
||||
bigdec("1234567890.123456789"),
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
NEGATIVE_HIGH_PRECISION_FLOAT,
|
||||
NumberValue(bigdec("-1234567890.1234567890123456789")),
|
||||
bigdec("-1234567890.123456789"),
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
SMALLEST_POSITIVE_FLOAT32,
|
||||
NumberValue(bigdec(Float.MIN_VALUE.toDouble())),
|
||||
bigdec(0),
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
SMALLEST_NEGATIVE_FLOAT32,
|
||||
NumberValue(bigdec(-Float.MIN_VALUE.toDouble())),
|
||||
bigdec(0),
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
SMALLEST_POSITIVE_FLOAT64,
|
||||
NumberValue(bigdec(Double.MIN_VALUE)),
|
||||
bigdec(0),
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
SMALLEST_NEGATIVE_FLOAT64,
|
||||
NumberValue(bigdec(-Double.MIN_VALUE)),
|
||||
bigdec(0),
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
// numeric bounds are perfectly fine
|
||||
case(NUMERIC_38_9_MAX, NumberValue(numeric38_9Max), numeric38_9Max),
|
||||
case(NUMERIC_38_9_MIN, NumberValue(numeric38_9Min), numeric38_9Min),
|
||||
// These values are out of bounds, so we null them
|
||||
case(
|
||||
LARGEST_POSITIVE_FLOAT32,
|
||||
NumberValue(bigdec(Float.MAX_VALUE.toDouble())),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
LARGEST_NEGATIVE_FLOAT32,
|
||||
NumberValue(bigdec(-Float.MAX_VALUE.toDouble())),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
LARGEST_POSITIVE_FLOAT64,
|
||||
NumberValue(bigdec(Double.MAX_VALUE)),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
LARGEST_NEGATIVE_FLOAT64,
|
||||
NumberValue(bigdec(-Double.MAX_VALUE)),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
SLIGHTLY_ABOVE_LARGEST_POSITIVE_FLOAT64,
|
||||
NumberValue(bigdec(Double.MAX_VALUE) + bigdec(Double.MIN_VALUE)),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
SLIGHTLY_BELOW_LARGEST_NEGATIVE_FLOAT64,
|
||||
NumberValue(bigdec(-Double.MAX_VALUE) - bigdec(Double.MIN_VALUE)),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
)
|
||||
.map { it.copy(outputValue = (it.outputValue as BigDecimal?)?.setScale(9)) }
|
||||
|
||||
@JvmStatic fun float64() = float64.toArgs()
|
||||
@JvmStatic fun numeric38_9() = numeric38_9.toArgs()
|
||||
}
|
||||
|
||||
const val SIMPLE_TIMESTAMP = "simple timestamp"
|
||||
const val UNIX_EPOCH = "unix epoch"
|
||||
const val MINIMUM_TIMESTAMP = "minimum timestamp"
|
||||
const val MAXIMUM_TIMESTAMP = "maximum timestamp"
|
||||
const val OUT_OF_RANGE_TIMESTAMP = "out of range timestamp"
|
||||
const val HIGH_PRECISION_TIMESTAMP = "high-precision timestamp"
|
||||
|
||||
object DataCoercionTimestampTzFixtures {
|
||||
/**
|
||||
* Many warehouses support timestamps between years 0001 - 9999.
|
||||
*
|
||||
* Depending on the exact warehouse, you may need to tweak the precision on some values. For
|
||||
* example, Snowflake supports nanoseconds-precision timestamps (9 decimal points), but Bigquery
|
||||
* only supports microseconds-precision (6 decimal points). Bigquery would probably do something
|
||||
* like:
|
||||
* ```kotlin
|
||||
* DataCoercionNumberFixtures.traditionalWarehouse
|
||||
* .map {
|
||||
* when (it.name) {
|
||||
* "maximum AD timestamp" -> it.copy(
|
||||
* inputValue = TimestampWithTimezoneValue("9999-12-31T23:59:59.999999Z"),
|
||||
* outputValue = OffsetDateTime.parse("9999-12-31T23:59:59.999999Z"),
|
||||
* changeReason = Reason.DESTINATION_FIELD_SIZE_LIMITATION,
|
||||
* )
|
||||
* "high-precision timestamp" -> it.copy(
|
||||
* outputValue = OffsetDateTime.parse("2025-01-23T01:01:00.123456Z"),
|
||||
* changeReason = Reason.DESTINATION_FIELD_SIZE_LIMITATION,
|
||||
* )
|
||||
* }
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
val commonWarehouse =
|
||||
listOf(
|
||||
case(NULL, NullValue, null),
|
||||
case(
|
||||
SIMPLE_TIMESTAMP,
|
||||
TimestampWithTimezoneValue("2025-01-23T12:34:56.789Z"),
|
||||
"2025-01-23T12:34:56.789Z",
|
||||
),
|
||||
case(
|
||||
UNIX_EPOCH,
|
||||
TimestampWithTimezoneValue("1970-01-01T00:00:00Z"),
|
||||
"1970-01-01T00:00:00Z",
|
||||
),
|
||||
case(
|
||||
MINIMUM_TIMESTAMP,
|
||||
TimestampWithTimezoneValue("0001-01-01T00:00:00Z"),
|
||||
"0001-01-01T00:00:00Z",
|
||||
),
|
||||
case(
|
||||
MAXIMUM_TIMESTAMP,
|
||||
TimestampWithTimezoneValue("9999-12-31T23:59:59.999999999Z"),
|
||||
"9999-12-31T23:59:59.999999999Z",
|
||||
),
|
||||
case(
|
||||
OUT_OF_RANGE_TIMESTAMP,
|
||||
TimestampWithTimezoneValue(odt("10000-01-01T00:00Z")),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION,
|
||||
),
|
||||
case(
|
||||
HIGH_PRECISION_TIMESTAMP,
|
||||
TimestampWithTimezoneValue("2025-01-23T01:01:00.123456789Z"),
|
||||
"2025-01-23T01:01:00.123456789Z",
|
||||
),
|
||||
)
|
||||
|
||||
@JvmStatic fun commonWarehouse() = commonWarehouse.toArgs()
|
||||
}
|
||||
|
||||
object DataCoercionTimestampNtzFixtures {
|
||||
/** See [DataCoercionTimestampTzFixtures.commonWarehouse] for explanation */
|
||||
val commonWarehouse =
|
||||
listOf(
|
||||
case(NULL, NullValue, null),
|
||||
case(
|
||||
SIMPLE_TIMESTAMP,
|
||||
TimestampWithoutTimezoneValue("2025-01-23T12:34:56.789"),
|
||||
"2025-01-23T12:34:56.789",
|
||||
),
|
||||
case(
|
||||
UNIX_EPOCH,
|
||||
TimestampWithoutTimezoneValue("1970-01-01T00:00:00"),
|
||||
"1970-01-01T00:00:00",
|
||||
),
|
||||
case(
|
||||
MINIMUM_TIMESTAMP,
|
||||
TimestampWithoutTimezoneValue("0001-01-01T00:00:00"),
|
||||
"0001-01-01T00:00:00",
|
||||
),
|
||||
case(
|
||||
MAXIMUM_TIMESTAMP,
|
||||
TimestampWithoutTimezoneValue("9999-12-31T23:59:59.999999999"),
|
||||
"9999-12-31T23:59:59.999999999",
|
||||
),
|
||||
case(
|
||||
OUT_OF_RANGE_TIMESTAMP,
|
||||
TimestampWithoutTimezoneValue(ldt("10000-01-01T00:00")),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION,
|
||||
),
|
||||
case(
|
||||
HIGH_PRECISION_TIMESTAMP,
|
||||
TimestampWithoutTimezoneValue("2025-01-23T01:01:00.123456789"),
|
||||
"2025-01-23T01:01:00.123456789",
|
||||
),
|
||||
)
|
||||
|
||||
@JvmStatic fun commonWarehouse() = commonWarehouse.toArgs()
|
||||
}
|
||||
|
||||
const val MIDNIGHT = "midnight"
|
||||
const val MAX_TIME = "max time"
|
||||
const val HIGH_NOON = "high noon"
|
||||
|
||||
object DataCoercionTimeTzFixtures {
|
||||
val timetz =
|
||||
listOf(
|
||||
case(NULL, NullValue, null),
|
||||
case(MIDNIGHT, TimeWithTimezoneValue("00:00Z"), "00:00Z"),
|
||||
case(MAX_TIME, TimeWithTimezoneValue("23:59:59.999999999Z"), "23:59:59.999999999Z"),
|
||||
case(HIGH_NOON, TimeWithTimezoneValue("12:00Z"), "12:00Z"),
|
||||
)
|
||||
|
||||
@JvmStatic fun timetz() = timetz.toArgs()
|
||||
}
|
||||
|
||||
object DataCoercionTimeNtzFixtures {
|
||||
val timentz =
|
||||
listOf(
|
||||
case(NULL, NullValue, null),
|
||||
case(MIDNIGHT, TimeWithoutTimezoneValue("00:00"), "00:00"),
|
||||
case(MAX_TIME, TimeWithoutTimezoneValue("23:59:59.999999999"), "23:59:59.999999999"),
|
||||
case(HIGH_NOON, TimeWithoutTimezoneValue("12:00"), "12:00"),
|
||||
)
|
||||
|
||||
@JvmStatic fun timentz() = timentz.toArgs()
|
||||
}
|
||||
|
||||
object DataCoercionDateFixtures {
|
||||
val commonWarehouse =
|
||||
listOf(
|
||||
case(NULL, NullValue, null),
|
||||
case(
|
||||
SIMPLE_TIMESTAMP,
|
||||
DateValue("2025-01-23"),
|
||||
"2025-01-23",
|
||||
),
|
||||
case(
|
||||
UNIX_EPOCH,
|
||||
DateValue("1970-01-01"),
|
||||
"1970-01-01",
|
||||
),
|
||||
case(
|
||||
MINIMUM_TIMESTAMP,
|
||||
DateValue("0001-01-01"),
|
||||
"0001-01-01",
|
||||
),
|
||||
case(
|
||||
MAXIMUM_TIMESTAMP,
|
||||
DateValue("9999-12-31"),
|
||||
"9999-12-31",
|
||||
),
|
||||
case(
|
||||
OUT_OF_RANGE_TIMESTAMP,
|
||||
DateValue(date("10000-01-01")),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION,
|
||||
),
|
||||
)
|
||||
|
||||
@JvmStatic fun commonWarehouse() = commonWarehouse.toArgs()
|
||||
}
|
||||
|
||||
object DataCoercionStringFixtures {
|
||||
const val EMPTY_STRING = "empty string"
|
||||
const val SHORT_STRING = "short string"
|
||||
const val LONG_STRING = "long string"
|
||||
const val SPECIAL_CHARS_STRING = "special chars string"
|
||||
|
||||
val strings =
|
||||
listOf(
|
||||
case(NULL, NullValue, null),
|
||||
case(EMPTY_STRING, StringValue(""), ""),
|
||||
case(SHORT_STRING, StringValue("foo"), "foo"),
|
||||
// Implementers may override this to test their destination-specific limits.
|
||||
// The default value is 8MB + 1 byte (slightly longer than snowflake's varchar limit).
|
||||
case(
|
||||
LONG_STRING,
|
||||
StringValue("a".repeat(16777216 + 1)),
|
||||
null,
|
||||
Reason.DESTINATION_FIELD_SIZE_LIMITATION
|
||||
),
|
||||
case(
|
||||
SPECIAL_CHARS_STRING,
|
||||
StringValue("`~!@#$%^&*()-=_+[]\\{}|o'O\",./<>?)Δ⅀↑∀"),
|
||||
"`~!@#$%^&*()-=_+[]\\{}|o'O\",./<>?)Δ⅀↑∀"
|
||||
),
|
||||
)
|
||||
|
||||
@JvmStatic fun strings() = strings.toArgs()
|
||||
}
|
||||
|
||||
object DataCoercionObjectFixtures {
|
||||
const val EMPTY_OBJECT = "empty object"
|
||||
const val NORMAL_OBJECT = "normal object"
|
||||
|
||||
val objects =
|
||||
listOf(
|
||||
case(NULL, NullValue, null),
|
||||
case(EMPTY_OBJECT, ObjectValue(linkedMapOf()), emptyMap<String, Any?>()),
|
||||
case(
|
||||
NORMAL_OBJECT,
|
||||
ObjectValue(linkedMapOf("foo" to StringValue("bar"))),
|
||||
mapOf("foo" to "bar")
|
||||
),
|
||||
)
|
||||
|
||||
val stringifiedObjects =
|
||||
objects.map { fixture ->
|
||||
fixture.copy(outputValue = fixture.outputValue?.serializeToString())
|
||||
}
|
||||
|
||||
@JvmStatic fun objects() = objects.toArgs()
|
||||
|
||||
@JvmStatic fun stringifiedObjects() = stringifiedObjects.toArgs()
|
||||
}
|
||||
|
||||
object DataCoercionArrayFixtures {
|
||||
const val EMPTY_ARRAY = "empty array"
|
||||
const val NORMAL_ARRAY = "normal array"
|
||||
|
||||
val arrays =
|
||||
listOf(
|
||||
case(NULL, NullValue, null),
|
||||
case(EMPTY_ARRAY, ArrayValue(emptyList()), emptyList<Any?>()),
|
||||
case(NORMAL_ARRAY, ArrayValue(listOf(StringValue("foo"))), listOf("foo")),
|
||||
)
|
||||
|
||||
val stringifiedArrays =
|
||||
arrays.map { fixture ->
|
||||
fixture.copy(outputValue = fixture.outputValue?.serializeToString())
|
||||
}
|
||||
|
||||
@JvmStatic fun arrays() = arrays.toArgs()
|
||||
|
||||
@JvmStatic fun stringifiedArrays() = stringifiedArrays.toArgs()
|
||||
}
|
||||
|
||||
const val UNION_INT_VALUE = "int value"
|
||||
const val UNION_OBJ_VALUE = "object value"
|
||||
const val UNION_STR_VALUE = "string value"
|
||||
|
||||
object DataCoercionUnionFixtures {
|
||||
val unions =
|
||||
listOf(
|
||||
case(NULL, NullValue, null),
|
||||
case(UNION_INT_VALUE, IntegerValue(42), 42L),
|
||||
case(UNION_STR_VALUE, StringValue("foo"), "foo"),
|
||||
case(
|
||||
UNION_OBJ_VALUE,
|
||||
ObjectValue(linkedMapOf("foo" to StringValue("bar"))),
|
||||
mapOf("foo" to "bar")
|
||||
),
|
||||
)
|
||||
|
||||
val stringifiedUnions =
|
||||
unions.map { fixture ->
|
||||
fixture.copy(outputValue = fixture.outputValue?.serializeToString())
|
||||
}
|
||||
|
||||
@JvmStatic fun unions() = unions.toArgs()
|
||||
|
||||
@JvmStatic fun stringifiedUnions() = stringifiedUnions.toArgs()
|
||||
}
|
||||
|
||||
object DataCoercionLegacyUnionFixtures {
|
||||
val unions =
|
||||
listOf(
|
||||
case(NULL, NullValue, null),
|
||||
// Legacy union of int x object will select object, and you can't write an int to an
|
||||
// object column.
|
||||
// So we should null it out.
|
||||
case(UNION_INT_VALUE, IntegerValue(42), null, Reason.DESTINATION_TYPECAST_ERROR),
|
||||
// Similarly, we should null out strings.
|
||||
case(UNION_STR_VALUE, StringValue("foo"), "foo"),
|
||||
// But objects can be written as objects, so retain this value.
|
||||
case(
|
||||
UNION_OBJ_VALUE,
|
||||
ObjectValue(linkedMapOf("foo" to StringValue("bar"))),
|
||||
mapOf("foo" to "bar")
|
||||
),
|
||||
)
|
||||
|
||||
val stringifiedUnions =
|
||||
DataCoercionUnionFixtures.unions.map { fixture ->
|
||||
fixture.copy(outputValue = fixture.outputValue?.serializeToString())
|
||||
}
|
||||
|
||||
@JvmStatic fun unions() = unions.toArgs()
|
||||
|
||||
@JvmStatic fun stringifiedUnions() = DataCoercionUnionFixtures.stringifiedUnions.toArgs()
|
||||
}
|
||||
|
||||
// This is pretty much identical to UnionFixtures, but separating them in case we need to add
|
||||
// different test cases for either of them.
|
||||
object DataCoercionUnknownFixtures {
|
||||
const val INT_VALUE = "integer value"
|
||||
const val STR_VALUE = "string value"
|
||||
const val OBJ_VALUE = "object value"
|
||||
|
||||
val unknowns =
|
||||
listOf(
|
||||
case(NULL, NullValue, null),
|
||||
case(INT_VALUE, IntegerValue(42), 42L),
|
||||
case(STR_VALUE, StringValue("foo"), "foo"),
|
||||
case(
|
||||
OBJ_VALUE,
|
||||
ObjectValue(linkedMapOf("foo" to StringValue("bar"))),
|
||||
mapOf("foo" to "bar")
|
||||
),
|
||||
)
|
||||
|
||||
val stringifiedUnknowns =
|
||||
unknowns.map { fixture ->
|
||||
fixture.copy(outputValue = fixture.outputValue?.serializeToString())
|
||||
}
|
||||
|
||||
@JvmStatic fun unknowns() = unknowns.toArgs()
|
||||
|
||||
@JvmStatic fun stringifiedUnknowns() = stringifiedUnknowns.toArgs()
|
||||
}
|
||||
|
||||
fun List<DataCoercionTestCase>.toArgs(): List<Arguments> =
    this.map { Arguments.argumentSet(it.name, it.inputValue, it.outputValue, it.changeReason) }
        .toList()

/**
 * Utility method to use the BigDecimal constructor (supports exponential notation like `1e38`) to
 * construct a BigInteger.
 */
fun bigint(str: String): BigInteger = BigDecimal(str).toBigIntegerExact()

/** Shorthand utility method to construct a bigint from a long */
fun bigint(long: Long): BigInteger = BigInteger.valueOf(long)

fun bigdec(str: String): BigDecimal = BigDecimal(str)

fun bigdec(double: Double): BigDecimal = BigDecimal.valueOf(double)

fun bigdec(int: Int): BigDecimal = BigDecimal.valueOf(int.toDouble())

fun odt(str: String): OffsetDateTime = OffsetDateTime.parse(str, dateTimeFormatter)

fun ldt(str: String): LocalDateTime = LocalDateTime.parse(str, dateTimeFormatter)

fun date(str: String): LocalDate = LocalDate.parse(str, dateFormatter)

// The default java.time.*.parse() behavior only accepts up to 4-digit years.
// Build a custom formatter to handle larger years.
val dateFormatter =
    DateTimeFormatterBuilder()
        // java.time.* supports up to 9-digit years
        .appendValue(ChronoField.YEAR, 1, 9, SignStyle.NORMAL)
        .appendLiteral('-')
        .appendValue(ChronoField.MONTH_OF_YEAR)
        .appendLiteral('-')
        .appendValue(ChronoField.DAY_OF_MONTH)
        .toFormatter()

val dateTimeFormatter =
    DateTimeFormatterBuilder()
        .append(dateFormatter)
        .appendLiteral('T')
        // Accepts strings with/without an offset, so we can use this formatter
        // for both timestamp with and without timezone
        .append(DateTimeFormatter.ISO_TIME)
        .toFormatter()

/**
 * Represents a single data coercion test case. You probably want to use [case] as a shorthand
 * constructor.
 *
 * @param name A short human-readable name for the test. Primarily useful for tests where
 *   [inputValue] is either very long, or otherwise hard to read.
 * @param inputValue The value to pass into [ValueCoercer.validate]
 * @param outputValue The value that we expect to read back from the destination. Should be
 *   basically equivalent to the output of [ValueCoercer.validate]
 * @param changeReason If `validate` returns Truncate/Nullify, the reason for that
 *   truncation/nullification. If `validate` returns Valid, this should be null.
 */
data class DataCoercionTestCase(
    val name: String,
    val inputValue: AirbyteValue,
    val outputValue: Any?,
    val changeReason: Reason? = null,
)

fun case(
    name: String,
    inputValue: AirbyteValue,
    outputValue: Any?,
    changeReason: Reason? = null,
) = DataCoercionTestCase(name, inputValue, outputValue, changeReason)

const val NULL = "null"
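Because each fixture is just a `List<DataCoercionTestCase>`, a destination with different limits can derive its own variant with `copy` and feed it to JUnit through `toArgs()`. A sketch under the assumption of a hypothetical 1 MB string limit (the object name and the limit itself are illustrative):

```kotlin
import io.airbyte.cdk.load.component.DataCoercionStringFixtures
import io.airbyte.cdk.load.component.toArgs
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Reason

// Hypothetical destination-specific variant of the string fixtures: the "long string"
// case is rewritten for a 1 MB limit instead of the default 8 MB + 1 byte value.
object MyDestinationStringFixtures {
    private val oneMegabyteString = "a".repeat(1_048_576 + 1)

    val strings =
        DataCoercionStringFixtures.strings.map { testCase ->
            if (testCase.name == DataCoercionStringFixtures.LONG_STRING) {
                testCase.copy(
                    inputValue = StringValue(oneMegabyteString),
                    outputValue = null,
                    changeReason = Reason.DESTINATION_FIELD_SIZE_LIMITATION,
                )
            } else {
                testCase
            }
        }

    @JvmStatic fun strings() = strings.toArgs()
}
```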
@@ -0,0 +1,369 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.cdk.load.component

import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.ArrayType
import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
import io.airbyte.cdk.load.data.BooleanType
import io.airbyte.cdk.load.data.BooleanValue
import io.airbyte.cdk.load.data.DateType
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerType
import io.airbyte.cdk.load.data.NumberType
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.TimeTypeWithTimezone
import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
import io.airbyte.cdk.load.data.UnionType
import io.airbyte.cdk.load.data.UnknownType
import io.airbyte.cdk.load.dataflow.transform.ValueCoercer
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.schema.TableSchemaFactory
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.util.Jsons
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Reason
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import kotlinx.coroutines.test.runTest

/**
 * The tests in this class are designed to reference the parameters defined in
 * `DataCoercionFixtures.kt`. For example, you might annotate [`handle integer values`] with
 * `@MethodSource("io.airbyte.cdk.load.component.DataCoercionIntegerFixtures#int32")`. See each
 * fixture class for explanations of what behavior they are exercising.
 *
 * Note that this class _only_ exercises [ValueCoercer.validate]. You should write separate unit
 * tests for [ValueCoercer.map]. For now, the `map` function is primarily intended for transforming
 * `UnionType` fields into other types (typically `StringType`), at which point your `validate`
 * implementation should be able to handle any StringValue (regardless of whether it was originally
 * a StringType or UnionType).
 */
@MicronautTest(environments = ["component"], resolveParameters = false)
interface DataCoercionSuite {
    val coercer: ValueCoercer
    val airbyteMetaColumnMapping: Map<String, String>
        get() = Meta.COLUMN_NAMES.associateWith { it }
    val columnNameMapping: ColumnNameMapping
        get() = ColumnNameMapping(mapOf("test" to "test"))

    val opsClient: TableOperationsClient
    val testClient: TestTableOperationsClient
    val schemaFactory: TableSchemaFactory

    val harness: TableOperationsTestHarness
        get() =
            TableOperationsTestHarness(
                opsClient,
                testClient,
                schemaFactory,
                airbyteMetaColumnMapping
            )

/** Fixtures are defined in [DataCoercionIntegerFixtures]. */
|
||||
fun `handle integer values`(
|
||||
inputValue: AirbyteValue,
|
||||
expectedValue: Any?,
|
||||
expectedChangeReason: Reason?
|
||||
) = runTest {
|
||||
harness.testValueCoercion(
|
||||
coercer,
|
||||
columnNameMapping,
|
||||
FieldType(IntegerType, nullable = true),
|
||||
inputValue,
|
||||
expectedValue,
|
||||
expectedChangeReason,
|
||||
)
|
||||
}
|
||||
|
||||
/** Fixtures are defined in [DataCoercionNumberFixtures]. */
|
||||
fun `handle number values`(
|
||||
inputValue: AirbyteValue,
|
||||
expectedValue: Any?,
|
||||
expectedChangeReason: Reason?
|
||||
) = runTest {
|
||||
harness.testValueCoercion(
|
||||
coercer,
|
||||
columnNameMapping,
|
||||
FieldType(NumberType, nullable = true),
|
||||
inputValue,
|
||||
expectedValue,
|
||||
expectedChangeReason,
|
||||
)
|
||||
}
|
||||
|
||||
/** Fixtures are defined in [DataCoercionTimestampTzFixtures]. */
|
||||
fun `handle timestamptz values`(
|
||||
inputValue: AirbyteValue,
|
||||
expectedValue: Any?,
|
||||
expectedChangeReason: Reason?
|
||||
) = runTest {
|
||||
harness.testValueCoercion(
|
||||
coercer,
|
||||
columnNameMapping,
|
||||
FieldType(TimestampTypeWithTimezone, nullable = true),
|
||||
inputValue,
|
||||
expectedValue,
|
||||
expectedChangeReason,
|
||||
)
|
||||
}
|
||||
|
||||
/** Fixtures are defined in [DataCoercionTimestampNtzFixtures]. */
|
||||
fun `handle timestampntz values`(
|
||||
inputValue: AirbyteValue,
|
||||
expectedValue: Any?,
|
||||
expectedChangeReason: Reason?
|
||||
) = runTest {
|
||||
harness.testValueCoercion(
|
||||
coercer,
|
||||
columnNameMapping,
|
||||
FieldType(TimestampTypeWithoutTimezone, nullable = true),
|
||||
inputValue,
|
||||
expectedValue,
|
||||
expectedChangeReason,
|
||||
)
|
||||
}
|
||||
|
||||
/** Fixtures are defined in [DataCoercionTimeTzFixtures]. */
|
||||
fun `handle timetz values`(
|
||||
inputValue: AirbyteValue,
|
||||
expectedValue: Any?,
|
||||
expectedChangeReason: Reason?
|
||||
) = runTest {
|
||||
harness.testValueCoercion(
|
||||
coercer,
|
||||
columnNameMapping,
|
||||
FieldType(TimeTypeWithTimezone, nullable = true),
|
||||
inputValue,
|
||||
expectedValue,
|
||||
expectedChangeReason,
|
||||
)
|
||||
}
|
||||
|
||||
/** Fixtures are defined in [DataCoercionTimeNtzFixtures]. */
|
||||
fun `handle timentz values`(
|
||||
inputValue: AirbyteValue,
|
||||
expectedValue: Any?,
|
||||
expectedChangeReason: Reason?
|
||||
) = runTest {
|
||||
harness.testValueCoercion(
|
||||
coercer,
|
||||
columnNameMapping,
|
||||
FieldType(TimeTypeWithoutTimezone, nullable = true),
|
||||
inputValue,
|
||||
expectedValue,
|
||||
expectedChangeReason,
|
||||
)
|
||||
}
|
||||
|
||||
/** Fixtures are defined in [DataCoercionDateFixtures]. */
|
||||
fun `handle date values`(
|
||||
inputValue: AirbyteValue,
|
||||
expectedValue: Any?,
|
||||
expectedChangeReason: Reason?
|
||||
) = runTest {
|
||||
harness.testValueCoercion(
|
||||
coercer,
|
||||
columnNameMapping,
|
||||
FieldType(DateType, nullable = true),
|
||||
inputValue,
|
||||
expectedValue,
|
||||
expectedChangeReason,
|
||||
)
|
||||
}
|
||||
|
||||
/** No fixtures, hardcoded to just write `true` */
|
||||
fun `handle bool values`(expectedValue: Any?) = runTest {
|
||||
harness.testValueCoercion(
|
||||
coercer,
|
||||
columnNameMapping,
|
||||
FieldType(BooleanType, nullable = true),
|
||||
// Just test on `true` and assume `false` also works
|
||||
BooleanValue(true),
|
||||
expectedValue,
|
||||
// If your destination is nulling/truncating booleans... that's almost definitely a bug
|
||||
expectedChangeReason = null,
|
||||
)
|
||||
}
|
||||
|
||||
/** Fixtures are defined in [DataCoercionStringFixtures]. */
|
||||
fun `handle string values`(
|
||||
inputValue: AirbyteValue,
|
||||
expectedValue: Any?,
|
||||
expectedChangeReason: Reason?
|
||||
) = runTest {
|
||||
harness.testValueCoercion(
|
||||
coercer,
|
||||
columnNameMapping,
|
||||
FieldType(StringType, nullable = true),
|
||||
inputValue,
|
||||
expectedValue,
|
||||
expectedChangeReason,
|
||||
)
|
||||
}
|
||||
|
||||
/** Fixtures are defined in [DataCoercionObjectFixtures]. */
|
||||
fun `handle object values`(
|
||||
inputValue: AirbyteValue,
|
||||
expectedValue: Any?,
|
||||
expectedChangeReason: Reason?
|
||||
) = runTest {
|
||||
harness.testValueCoercion(
|
||||
coercer,
|
||||
columnNameMapping,
|
||||
FieldType(
|
||||
ObjectType(linkedMapOf("foo" to FieldType(StringType, true))),
|
||||
nullable = true
|
||||
),
|
||||
inputValue,
|
||||
expectedValue,
|
||||
expectedChangeReason,
|
||||
)
|
||||
}
|
||||
|
||||
/** Fixtures are defined in [DataCoercionObjectFixtures]. */
|
||||
fun `handle empty object values`(
|
||||
inputValue: AirbyteValue,
|
||||
expectedValue: Any?,
|
||||
expectedChangeReason: Reason?
|
||||
) = runTest {
|
||||
harness.testValueCoercion(
|
||||
coercer,
|
||||
columnNameMapping,
|
||||
FieldType(ObjectTypeWithEmptySchema, nullable = true),
|
||||
inputValue,
|
||||
expectedValue,
|
||||
expectedChangeReason,
|
||||
)
|
||||
}
|
||||
|
||||
/** Fixtures are defined in [DataCoercionObjectFixtures]. */
|
||||
fun `handle schemaless object values`(
|
||||
inputValue: AirbyteValue,
|
||||
expectedValue: Any?,
|
||||
expectedChangeReason: Reason?
|
||||
) = runTest {
|
||||
harness.testValueCoercion(
|
||||
coercer,
|
||||
columnNameMapping,
|
||||
FieldType(ObjectTypeWithoutSchema, nullable = true),
|
||||
inputValue,
|
||||
expectedValue,
|
||||
expectedChangeReason,
|
||||
)
|
||||
}
|
||||
|
||||
/** Fixtures are defined in [DataCoercionArrayFixtures]. */
|
||||
fun `handle array values`(
|
||||
inputValue: AirbyteValue,
|
||||
expectedValue: Any?,
|
||||
expectedChangeReason: Reason?
|
||||
) = runTest {
|
||||
harness.testValueCoercion(
|
||||
coercer,
|
||||
columnNameMapping,
|
||||
FieldType(ArrayType(FieldType(StringType, true)), nullable = true),
|
||||
inputValue,
|
||||
expectedValue,
|
||||
expectedChangeReason,
|
||||
)
|
||||
}
|
||||
|
||||
/** Fixtures are defined in [DataCoercionArrayFixtures]. */
|
||||
fun `handle schemaless array values`(
|
||||
inputValue: AirbyteValue,
|
||||
expectedValue: Any?,
|
||||
expectedChangeReason: Reason?
|
||||
) = runTest {
|
||||
harness.testValueCoercion(
|
||||
coercer,
|
||||
columnNameMapping,
|
||||
FieldType(ArrayTypeWithoutSchema, nullable = true),
|
||||
inputValue,
|
||||
expectedValue,
|
||||
expectedChangeReason,
|
||||
)
|
||||
}
|
||||
|
||||
/**
 * All destinations should implement this, even if your destination is supporting legacy unions.
 *
 * Fixtures are defined in [DataCoercionUnionFixtures].
 */
|
||||
fun `handle union values`(
|
||||
inputValue: AirbyteValue,
|
||||
expectedValue: Any?,
|
||||
expectedChangeReason: Reason?
|
||||
) = runTest {
|
||||
harness.testValueCoercion(
|
||||
coercer,
|
||||
columnNameMapping,
|
||||
FieldType(
|
||||
UnionType(
|
||||
setOf(
|
||||
ObjectType(linkedMapOf("foo" to FieldType(StringType, true))),
|
||||
IntegerType,
|
||||
StringType,
|
||||
),
|
||||
isLegacyUnion = false
|
||||
),
|
||||
nullable = true
|
||||
),
|
||||
inputValue,
|
||||
expectedValue,
|
||||
expectedChangeReason,
|
||||
)
|
||||
}
|
||||
|
||||
/**
 * Only legacy destinations that are maintaining "legacy" union behavior should implement this
 * test. If you're not sure, check whether your `application-connector.yaml` includes an
 * `airbyte.destination.core.types.unions: LEGACY` property.
 *
 * Fixtures are defined in [DataCoercionLegacyUnionFixtures].
 */
|
||||
fun `handle legacy union values`(
|
||||
inputValue: AirbyteValue,
|
||||
expectedValue: Any?,
|
||||
expectedChangeReason: Reason?
|
||||
) = runTest {
|
||||
harness.testValueCoercion(
|
||||
coercer,
|
||||
columnNameMapping,
|
||||
FieldType(
|
||||
UnionType(
|
||||
setOf(
|
||||
ObjectType(linkedMapOf("foo" to FieldType(StringType, true))),
|
||||
IntegerType,
|
||||
StringType,
|
||||
),
|
||||
isLegacyUnion = true
|
||||
),
|
||||
nullable = true
|
||||
),
|
||||
inputValue,
|
||||
expectedValue,
|
||||
expectedChangeReason,
|
||||
)
|
||||
}
|
||||
|
||||
fun `handle unknown values`(
|
||||
inputValue: AirbyteValue,
|
||||
expectedValue: Any?,
|
||||
expectedChangeReason: Reason?
|
||||
) = runTest {
|
||||
harness.testValueCoercion(
|
||||
coercer,
|
||||
columnNameMapping,
|
||||
FieldType(UnknownType(Jsons.readTree(("""{"type": "potato"}"""))), nullable = true),
|
||||
inputValue,
|
||||
expectedValue,
|
||||
expectedChangeReason,
|
||||
)
|
||||
}
|
||||
}
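For reference, a connector typically adopts this suite by overriding each test with `@ParameterizedTest`/`@MethodSource` pointing at the shared fixtures, mirroring the ClickHouse test added later in this change. A minimal sketch (the class name is hypothetical; the injected beans follow the suite's properties):

```kotlin
@MicronautTest(environments = ["component"], resolveParameters = false)
class MyDestinationDataCoercionTest(
    override val coercer: ValueCoercer,
    override val opsClient: TableOperationsClient,
    override val testClient: TestTableOperationsClient,
    override val schemaFactory: TableSchemaFactory,
) : DataCoercionSuite {
    @ParameterizedTest
    @MethodSource("io.airbyte.cdk.load.component.DataCoercionIntegerFixtures#int64")
    override fun `handle integer values`(
        inputValue: AirbyteValue,
        expectedValue: Any?,
        expectedChangeReason: Reason?
    ) {
        super.`handle integer values`(inputValue, expectedValue, expectedChangeReason)
    }

    // ...remaining overrides follow the same pattern.
}
```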
|
||||
@@ -23,6 +23,7 @@ import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
|
||||
import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
|
||||
import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
|
||||
import io.airbyte.cdk.load.data.TimestampWithTimezoneValue
|
||||
import io.airbyte.cdk.load.data.UnionType
|
||||
import io.airbyte.cdk.load.data.UnknownType
|
||||
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
|
||||
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
|
||||
@@ -84,6 +85,18 @@ object TableOperationsFixtures {
|
||||
"array" to FieldType(ArrayType(FieldType(StringType, true)), true),
|
||||
"object" to
|
||||
FieldType(ObjectType(linkedMapOf("key" to FieldType(StringType, true))), true),
|
||||
"union" to
|
||||
FieldType(
|
||||
UnionType(setOf(StringType, IntegerType), isLegacyUnion = false),
|
||||
true
|
||||
),
|
||||
// Most destinations just ignore the isLegacyUnion flag, which is totally fine.
|
||||
// This is here for the small set of connectors that respect it.
|
||||
"legacy_union" to
|
||||
FieldType(
|
||||
UnionType(setOf(StringType, IntegerType), isLegacyUnion = true),
|
||||
true
|
||||
),
|
||||
"unknown" to FieldType(UnknownType(Jsons.readTree("""{"type": "potato"}""")), true),
|
||||
),
|
||||
)
|
||||
@@ -101,6 +114,8 @@ object TableOperationsFixtures {
|
||||
"time_ntz" to "time_ntz",
|
||||
"array" to "array",
|
||||
"object" to "object",
|
||||
"union" to "union",
|
||||
"legacy_union" to "legacy_union",
|
||||
"unknown" to "unknown",
|
||||
)
|
||||
)
|
||||
@@ -714,6 +729,11 @@ object TableOperationsFixtures {
|
||||
return map { record -> record.mapKeys { (k, _) -> totalMapping.invert()[k] ?: k } }
|
||||
}
|
||||
|
||||
    fun <V> List<Map<String, V>>.removeAirbyteColumns(
        airbyteMetaColumnMapping: Map<String, String>
    ): List<Map<String, V>> =
        this.map { rec -> rec.filter { !airbyteMetaColumnMapping.containsValue(it.key) } }

    fun <V> List<Map<String, V>>.removeNulls() =
        this.map { record -> record.filterValues { it != null } }
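As a rough illustration of what the two helpers above do to a read-back row (the row contents here are hypothetical):

```kotlin
import io.airbyte.cdk.load.message.Meta

// One user column with data, one user column that came back null,
// plus two Airbyte meta columns.
val rows: List<Map<String, Any?>> =
    listOf(
        mapOf(
            "test" to "hello",
            "untouched" to null,
            "_airbyte_raw_id" to "some-uuid",
            "_airbyte_extracted_at" to 1700000000000L,
        )
    )
val metaMapping = Meta.COLUMN_NAMES.associateWith { it }

// Drops the _airbyte_* columns, then the null-valued column,
// leaving listOf(mapOf("test" to "hello")).
val cleaned = rows.removeAirbyteColumns(metaMapping).removeNulls()
```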
|
||||
|
||||
|
||||
@@ -58,7 +58,8 @@ interface TableOperationsSuite {
|
||||
get() = Meta.COLUMN_NAMES.associateWith { it }
|
||||
|
||||
private val harness: TableOperationsTestHarness
|
||||
get() = TableOperationsTestHarness(client, testClient, airbyteMetaColumnMapping)
|
||||
get() =
|
||||
TableOperationsTestHarness(client, testClient, schemaFactory, airbyteMetaColumnMapping)
|
||||
|
||||
/** Tests basic database connectivity by pinging the database. */
|
||||
fun `connect to database`() = runTest { assertDoesNotThrow { testClient.ping() } }
|
||||
@@ -606,7 +607,7 @@ interface TableOperationsSuite {
|
||||
val targetTableSchema =
|
||||
schemaFactory.make(
|
||||
targetTable,
|
||||
Fixtures.TEST_INTEGER_SCHEMA.properties,
|
||||
Fixtures.ID_TEST_WITH_CDC_SCHEMA.properties,
|
||||
Dedupe(
|
||||
primaryKey = listOf(listOf(Fixtures.ID_FIELD)),
|
||||
cursor = listOf(Fixtures.TEST_FIELD),
|
||||
|
||||
@@ -4,11 +4,24 @@
|
||||
|
||||
package io.airbyte.cdk.load.component
|
||||
|
||||
import io.airbyte.cdk.load.command.Append
|
||||
import io.airbyte.cdk.load.command.DestinationStream
|
||||
import io.airbyte.cdk.load.component.TableOperationsFixtures.inputRecord
|
||||
import io.airbyte.cdk.load.component.TableOperationsFixtures.insertRecords
|
||||
import io.airbyte.cdk.load.component.TableOperationsFixtures.removeAirbyteColumns
|
||||
import io.airbyte.cdk.load.component.TableOperationsFixtures.removeNulls
|
||||
import io.airbyte.cdk.load.component.TableOperationsFixtures.reverseColumnNameMapping
|
||||
import io.airbyte.cdk.load.data.AirbyteValue
|
||||
import io.airbyte.cdk.load.data.EnrichedAirbyteValue
|
||||
import io.airbyte.cdk.load.data.FieldType
|
||||
import io.airbyte.cdk.load.data.NullValue
|
||||
import io.airbyte.cdk.load.data.ObjectType
|
||||
import io.airbyte.cdk.load.dataflow.transform.ValidationResult
|
||||
import io.airbyte.cdk.load.dataflow.transform.ValueCoercer
|
||||
import io.airbyte.cdk.load.schema.TableSchemaFactory
|
||||
import io.airbyte.cdk.load.schema.model.TableName
|
||||
import io.airbyte.cdk.load.table.ColumnNameMapping
|
||||
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Reason
|
||||
import io.github.oshai.kotlinlogging.KotlinLogging
|
||||
import org.junit.jupiter.api.Assertions.assertEquals
|
||||
|
||||
@@ -21,6 +34,7 @@ private val log = KotlinLogging.logger {}
|
||||
class TableOperationsTestHarness(
|
||||
private val client: TableOperationsClient,
|
||||
private val testClient: TestTableOperationsClient,
|
||||
private val schemaFactory: TableSchemaFactory,
|
||||
private val airbyteMetaColumnMapping: Map<String, String>,
|
||||
) {
|
||||
|
||||
@@ -100,8 +114,77 @@ class TableOperationsTestHarness(
|
||||
/** Reads records from a table, filtering out Meta columns. */
|
||||
suspend fun readTableWithoutMetaColumns(tableName: TableName): List<Map<String, Any>> {
|
||||
val tableRead = testClient.readTable(tableName)
|
||||
return tableRead.map { rec ->
|
||||
rec.filter { !airbyteMetaColumnMapping.containsValue(it.key) }
|
||||
return tableRead.removeAirbyteColumns(airbyteMetaColumnMapping)
|
||||
}
|
||||
|
||||
/** Apply the coercer to a value and verify that we can write the coerced value correctly */
|
||||
suspend fun testValueCoercion(
|
||||
coercer: ValueCoercer,
|
||||
columnNameMapping: ColumnNameMapping,
|
||||
fieldType: FieldType,
|
||||
inputValue: AirbyteValue,
|
||||
expectedValue: Any?,
|
||||
expectedChangeReason: Reason?,
|
||||
) {
|
||||
val testNamespace = TableOperationsFixtures.generateTestNamespace("test")
|
||||
val tableName =
|
||||
TableOperationsFixtures.generateTestTableName("table-test-table", testNamespace)
|
||||
val schema = ObjectType(linkedMapOf("test" to fieldType))
|
||||
val tableSchema = schemaFactory.make(tableName, schema.properties, Append)
|
||||
val stream =
|
||||
TableOperationsFixtures.createStream(
|
||||
namespace = tableName.namespace,
|
||||
name = tableName.name,
|
||||
tableSchema = tableSchema,
|
||||
)
|
||||
|
||||
val inputValueAsEnrichedAirbyteValue =
|
||||
EnrichedAirbyteValue(
|
||||
inputValue,
|
||||
fieldType.type,
|
||||
"test",
|
||||
airbyteMetaField = null,
|
||||
)
|
||||
val validatedValue = coercer.validate(inputValueAsEnrichedAirbyteValue)
|
||||
val valueToInsert: AirbyteValue
|
||||
val changeReason: Reason?
|
||||
when (validatedValue) {
|
||||
is ValidationResult.ShouldNullify -> {
|
||||
valueToInsert = NullValue
|
||||
changeReason = validatedValue.reason
|
||||
}
|
||||
is ValidationResult.ShouldTruncate -> {
|
||||
valueToInsert = validatedValue.truncatedValue
|
||||
changeReason = validatedValue.reason
|
||||
}
|
||||
ValidationResult.Valid -> {
|
||||
valueToInsert = inputValue
|
||||
changeReason = null
|
||||
}
|
||||
}
|
||||
|
||||
client.createNamespace(testNamespace)
|
||||
client.createTable(stream, tableName, columnNameMapping, replace = false)
|
||||
testClient.insertRecords(
|
||||
tableName,
|
||||
columnNameMapping,
|
||||
inputRecord("test" to valueToInsert),
|
||||
)
|
||||
|
||||
val actualRecords =
|
||||
testClient
|
||||
.readTable(tableName)
|
||||
.removeAirbyteColumns(airbyteMetaColumnMapping)
|
||||
.reverseColumnNameMapping(columnNameMapping, airbyteMetaColumnMapping)
|
||||
.removeNulls()
|
||||
val actualValue = actualRecords.first()["test"]
|
||||
assertEquals(
|
||||
expectedValue,
|
||||
actualValue,
|
||||
"For input $inputValue, expected ${expectedValue.simpleClassName()}; actual value was ${actualValue.simpleClassName()}. Coercer output was $validatedValue.",
|
||||
)
|
||||
assertEquals(expectedChangeReason, changeReason)
|
||||
}
|
||||
}

fun Any?.simpleClassName() = this?.let { it::class.simpleName } ?: "null"
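To make the contract exercised by `testValueCoercion` concrete, here is a deliberately simplified, hypothetical coercer that nullifies oversized strings. The exact `ValueCoercer`/`EnrichedAirbyteValue` member names are assumed from how the harness above consumes them; real connectors have their own limits and change reasons:

```kotlin
import io.airbyte.cdk.load.data.EnrichedAirbyteValue
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.dataflow.transform.ValidationResult
import io.airbyte.cdk.load.dataflow.transform.ValueCoercer
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Reason

// Hypothetical sketch, not code from this change.
class ExampleStringLimitCoercer : ValueCoercer {
    // `map` is assumed to be an identity-capable pre-processing hook (see the suite KDoc).
    override fun map(value: EnrichedAirbyteValue): EnrichedAirbyteValue = value

    override fun validate(value: EnrichedAirbyteValue): ValidationResult {
        val inner = value.abValue // assumed property holding the wrapped AirbyteValue
        return if (inner is StringValue && inner.value.length > 1_000) {
            // The harness will then insert NullValue and expect this change reason back.
            ValidationResult.ShouldNullify(Reason.DESTINATION_FIELD_SIZE_LIMITATION)
        } else {
            ValidationResult.Valid
        }
    }
}
```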
|
||||
|
||||
@@ -44,7 +44,13 @@ interface TableSchemaEvolutionSuite {
|
||||
val schemaFactory: TableSchemaFactory
|
||||
|
||||
private val harness: TableOperationsTestHarness
|
||||
get() = TableOperationsTestHarness(opsClient, testClient, airbyteMetaColumnMapping)
|
||||
get() =
|
||||
TableOperationsTestHarness(
|
||||
opsClient,
|
||||
testClient,
|
||||
schemaFactory,
|
||||
airbyteMetaColumnMapping
|
||||
)
|
||||
|
||||
/**
|
||||
* Test that the connector can correctly discover all of its own data types. This test creates a
|
||||
|
||||
@@ -1 +1 @@
|
||||
version=0.1.87
|
||||
version=0.1.91
|
||||
|
||||
@@ -10,5 +10,6 @@ CONNECTOR_PATH_PREFIXES = {
|
||||
"airbyte-integrations/connectors",
|
||||
"docs/integrations/sources",
|
||||
"docs/integrations/destinations",
|
||||
"docs/ai-agents/connectors",
|
||||
}
|
||||
MERGE_METHOD = "squash"
|
||||
|
||||
@@ -75,7 +75,7 @@ This will copy the specified connector version to your development bucket. This
_💡 Note: A prerequisite is you have [gsutil](https://cloud.google.com/storage/docs/gsutil) installed and have run `gsutil auth login`_

```bash
TARGET_BUCKET=<YOUR-DEV_BUCKET> CONNECTOR="airbyte/source-stripe" VERSION="3.17.0-dev.ea013c8741" poetry run poe copy-connector-from-prod
TARGET_BUCKET=<YOUR-DEV_BUCKET> CONNECTOR="airbyte/source-stripe" VERSION="3.17.0-preview.ea013c8" poetry run poe copy-connector-from-prod
```

### Promote Connector Version to Latest
@@ -87,5 +87,5 @@ _💡 Note: A prerequisite is you have [gsutil](https://cloud.google.com/storage
_⚠️ Warning: It's important to know that this will remove ANY existing files in the latest folder that are not in the versioned folder as it calls `gsutil rsync` with `-d` enabled._

```bash
TARGET_BUCKET=<YOUR-DEV_BUCKET> CONNECTOR="airbyte/source-stripe" VERSION="3.17.0-dev.ea013c8741" poetry run poe promote-connector-to-latest
TARGET_BUCKET=<YOUR-DEV_BUCKET> CONNECTOR="airbyte/source-stripe" VERSION="3.17.0-preview.ea013c8" poetry run poe promote-connector-to-latest
```
|
||||
|
||||
@@ -28,8 +28,8 @@ def get_docker_hub_auth_token() -> str:
|
||||
|
||||
|
||||
def get_docker_hub_headers() -> Dict | None:
|
||||
if "DOCKER_HUB_USERNAME" not in os.environ or "DOCKER_HUB_PASSWORD" not in os.environ:
|
||||
# If the Docker Hub credentials are not provided, we can only anonymously call the Docker Hub API.
|
||||
if not os.environ.get("DOCKER_HUB_USERNAME") or not os.environ.get("DOCKER_HUB_PASSWORD"):
|
||||
# If the Docker Hub credentials are not provided (or are empty), we can only anonymously call the Docker Hub API.
|
||||
# This will only work for public images and lead to a lower rate limit.
|
||||
return {}
|
||||
else:
|
||||
|
||||
@@ -434,7 +434,7 @@ def generate_and_persist_registry_entry(
|
||||
bucket_name (str): The name of the GCS bucket.
|
||||
repo_metadata_file_path (pathlib.Path): The path to the spec file.
|
||||
registry_type (str): The registry type.
|
||||
docker_image_tag (str): The docker image tag associated with this release. Typically a semver string (e.g. '1.2.3'), possibly with a suffix (e.g. '1.2.3-dev.abcde12345')
|
||||
docker_image_tag (str): The docker image tag associated with this release. Typically a semver string (e.g. '1.2.3'), possibly with a suffix (e.g. '1.2.3-preview.abcde12')
|
||||
is_prerelease (bool): Whether this is a prerelease, or a main release.
|
||||
"""
|
||||
# Read the repo metadata dict to bootstrap ourselves. We need the docker repository,
|
||||
@@ -444,7 +444,7 @@ def generate_and_persist_registry_entry(
|
||||
|
||||
try:
|
||||
# Now that we have the docker repo, read the appropriate versioned metadata from GCS.
|
||||
# This metadata will differ in a few fields (e.g. in prerelease mode, dockerImageTag will contain the actual prerelease tag `1.2.3-dev.abcde12345`),
|
||||
# This metadata will differ in a few fields (e.g. in prerelease mode, dockerImageTag will contain the actual prerelease tag `1.2.3-preview.abcde12`),
|
||||
# so we'll treat this as the source of truth (ish. See below for how we handle the registryOverrides field.)
|
||||
gcs_client = get_gcs_storage_client(gcs_creds=os.environ.get("GCS_CREDENTIALS"))
|
||||
bucket = gcs_client.bucket(bucket_name)
|
||||
@@ -533,7 +533,9 @@ def generate_and_persist_registry_entry(
|
||||
|
||||
# For latest versions that are disabled, delete any existing registry entry to remove it from the registry
|
||||
if (
|
||||
"-rc" not in metadata_dict["data"]["dockerImageTag"] and "-dev" not in metadata_dict["data"]["dockerImageTag"]
|
||||
"-rc" not in metadata_dict["data"]["dockerImageTag"]
|
||||
and "-dev" not in metadata_dict["data"]["dockerImageTag"]
|
||||
and "-preview" not in metadata_dict["data"]["dockerImageTag"]
|
||||
) and not metadata_dict["data"]["registryOverrides"][registry_type]["enabled"]:
|
||||
logger.info(
|
||||
f"{registry_type} is not enabled: deleting existing {registry_type} registry entry for {metadata_dict['data']['dockerRepository']} at latest path."
|
||||
|
||||
@@ -5,7 +5,7 @@ data:
|
||||
connectorType: source
|
||||
dockerRepository: airbyte/image-exists-1
|
||||
githubIssueLabel: source-alloydb-strict-encrypt
|
||||
dockerImageTag: 2.0.0-dev.cf3628ccf3
|
||||
dockerImageTag: 2.0.0-preview.cf3628c
|
||||
documentationUrl: https://docs.airbyte.com/integrations/sources/existingsource
|
||||
connectorSubtype: database
|
||||
releaseStage: generally_available
|
||||
|
||||
@@ -231,7 +231,7 @@ def test_upload_prerelease(mocker, valid_metadata_yaml_files, tmp_path):
|
||||
mocker.patch.object(commands.click, "secho")
|
||||
mocker.patch.object(commands, "upload_metadata_to_gcs")
|
||||
|
||||
prerelease_tag = "0.3.0-dev.6d33165120"
|
||||
prerelease_tag = "0.3.0-preview.6d33165"
|
||||
bucket = "my-bucket"
|
||||
metadata_file_path = valid_metadata_yaml_files[0]
|
||||
validator_opts = ValidatorOptions(docs_path=str(tmp_path), prerelease_tag=prerelease_tag)
|
||||
|
||||
@@ -582,7 +582,7 @@ def test_upload_metadata_to_gcs_invalid_docker_images(mocker, invalid_metadata_u
|
||||
def test_upload_metadata_to_gcs_with_prerelease(mocker, valid_metadata_upload_files, tmp_path):
|
||||
mocker.spy(gcs_upload, "_file_upload")
|
||||
mocker.spy(gcs_upload, "upload_file_if_changed")
|
||||
prerelease_image_tag = "1.5.6-dev.f80318f754"
|
||||
prerelease_image_tag = "1.5.6-preview.f80318f"
|
||||
|
||||
for valid_metadata_upload_file in valid_metadata_upload_files:
|
||||
tmp_metadata_file_path = tmp_path / "metadata.yaml"
|
||||
@@ -701,7 +701,7 @@ def test_upload_metadata_to_gcs_release_candidate(mocker, get_fixture_path, tmp_
|
||||
)
|
||||
assert metadata.data.releases.rolloutConfiguration.enableProgressiveRollout
|
||||
|
||||
prerelease_tag = "1.5.6-dev.f80318f754" if prerelease else None
|
||||
prerelease_tag = "1.5.6-preview.f80318f" if prerelease else None
|
||||
|
||||
upload_info = gcs_upload.upload_metadata_to_gcs(
|
||||
"my_bucket",
|
||||
|
||||
@@ -110,14 +110,14 @@ class PublishConnectorContext(ConnectorContext):
|
||||
|
||||
@property
|
||||
def pre_release_suffix(self) -> str:
|
||||
return self.git_revision[:10]
|
||||
return self.git_revision[:7]
|
||||
|
||||
@property
|
||||
def docker_image_tag(self) -> str:
|
||||
# get the docker image tag from the parent class
|
||||
metadata_tag = super().docker_image_tag
|
||||
if self.pre_release:
|
||||
return f"{metadata_tag}-dev.{self.pre_release_suffix}"
|
||||
return f"{metadata_tag}-preview.{self.pre_release_suffix}"
|
||||
else:
|
||||
return metadata_tag
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ from pipelines.helpers.utils import raise_if_not_user
|
||||
from pipelines.models.steps import STEP_PARAMS, Step, StepResult
|
||||
|
||||
# Pin the PyAirbyte version to avoid updates from breaking CI
|
||||
PYAIRBYTE_VERSION = "0.20.2"
|
||||
PYAIRBYTE_VERSION = "0.35.1"
|
||||
|
||||
|
||||
class PytestStep(Step, ABC):
|
||||
|
||||
@@ -156,7 +156,8 @@ class TestPyAirbyteValidationTests:
|
||||
result = await PyAirbyteValidation(context_for_valid_connector)._run(mocker.MagicMock())
|
||||
assert isinstance(result, StepResult)
|
||||
assert result.status == StepStatus.SUCCESS
|
||||
assert "Getting `spec` output from connector..." in result.stdout
|
||||
# Verify the connector name appears in output (stable across PyAirbyte versions)
|
||||
assert context_for_valid_connector.connector.technical_name in (result.stdout + result.stderr)
|
||||
|
||||
async def test__run_validation_skip_unpublished_connector(
|
||||
self,
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
cdkVersion=0.1.86
|
||||
cdkVersion=0.1.89
|
||||
JunitMethodExecutionTimeout=10m
|
||||
|
||||
@@ -2,7 +2,7 @@ data:
|
||||
connectorSubtype: database
|
||||
connectorType: destination
|
||||
definitionId: ce0d828e-1dc4-496c-b122-2da42e637e48
|
||||
dockerImageTag: 2.1.16-rc.2
|
||||
dockerImageTag: 2.1.18
|
||||
dockerRepository: airbyte/destination-clickhouse
|
||||
githubIssueLabel: destination-clickhouse
|
||||
icon: clickhouse.svg
|
||||
@@ -27,7 +27,7 @@ data:
|
||||
releaseStage: generally_available
|
||||
releases:
|
||||
rolloutConfiguration:
|
||||
enableProgressiveRollout: true
|
||||
enableProgressiveRollout: false
|
||||
breakingChanges:
|
||||
2.0.0:
|
||||
message: "This connector has been re-written from scratch. Data will now be typed and stored in final (non-raw) tables. The connector may require changes to its configuration to function properly and downstream pipelines may be affected. Warning: SSH tunneling is in Beta."
|
||||
|
||||
@@ -54,8 +54,11 @@ class ClickhouseSqlGenerator {
|
||||
// Check if cursor column type is valid for ClickHouse ReplacingMergeTree
|
||||
val cursor = tableSchema.getCursor().firstOrNull()
|
||||
val cursorType = cursor?.let { finalSchema[it]?.type }
|
||||
|
||||
val useCursorAsVersion =
|
||||
cursorType != null && isValidVersionColumn(cursor, cursorType)
|
||||
val versionColumn =
|
||||
if (cursorType?.isValidVersionColumnType() ?: false) {
|
||||
if (useCursorAsVersion) {
|
||||
"`$cursor`"
|
||||
} else {
|
||||
// Fallback to _airbyte_extracted_at if no cursor is specified or cursor
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
package io.airbyte.integrations.destination.clickhouse.client
|
||||
|
||||
import io.airbyte.cdk.load.table.CDC_CURSOR_COLUMN
|
||||
import io.airbyte.integrations.destination.clickhouse.client.ClickhouseSqlTypes.VALID_VERSION_COLUMN_TYPES
|
||||
|
||||
object ClickhouseSqlTypes {
|
||||
@@ -23,4 +24,9 @@ object ClickhouseSqlTypes {
|
||||
)
|
||||
}
|
||||
|
||||
fun String.isValidVersionColumnType() = VALID_VERSION_COLUMN_TYPES.contains(this)
// Warning: if any munging changes the CDC column name, this will break.
// Currently, that is not the case.
fun isValidVersionColumn(name: String, type: String) =
    // CDC cursors cannot be used as a version column since they are null
    // during the initial CDC snapshot.
    name != CDC_CURSOR_COLUMN && VALID_VERSION_COLUMN_TYPES.contains(type)
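The practical effect: a typed cursor column can drive the ReplacingMergeTree version, while the CDC cursor always falls back to `_airbyte_extracted_at`. A small illustration (the column name is made up, and `DateTime64(3)` being in `VALID_VERSION_COLUMN_TYPES` is an assumption):

```kotlin
// Illustrative only.
val typedCursorOk = isValidVersionColumn("updated_at", "DateTime64(3)")
// false regardless of type: CDC cursors are null during the initial snapshot.
val cdcCursorRejected = isValidVersionColumn(CDC_CURSOR_COLUMN, "DateTime64(3)")
```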
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination.clickhouse.config
|
||||
|
||||
import io.airbyte.cdk.load.command.DestinationStream
|
||||
import io.airbyte.cdk.load.data.Transformations.Companion.toAlphanumericAndUnderscore
|
||||
import io.airbyte.cdk.load.schema.model.TableName
|
||||
import io.airbyte.cdk.load.table.ColumnNameGenerator
|
||||
import io.airbyte.cdk.load.table.FinalTableNameGenerator
|
||||
import io.airbyte.integrations.destination.clickhouse.spec.ClickhouseConfiguration
|
||||
import jakarta.inject.Singleton
|
||||
import java.util.Locale
|
||||
import java.util.UUID
|
||||
|
||||
@Singleton
|
||||
class ClickhouseFinalTableNameGenerator(private val config: ClickhouseConfiguration) :
|
||||
FinalTableNameGenerator {
|
||||
override fun getTableName(streamDescriptor: DestinationStream.Descriptor) =
|
||||
TableName(
|
||||
namespace =
|
||||
(streamDescriptor.namespace ?: config.resolvedDatabase)
|
||||
.toClickHouseCompatibleName(),
|
||||
name = streamDescriptor.name.toClickHouseCompatibleName(),
|
||||
)
|
||||
}
|
||||
|
||||
@Singleton
|
||||
class ClickhouseColumnNameGenerator : ColumnNameGenerator {
|
||||
override fun getColumnName(column: String): ColumnNameGenerator.ColumnName {
|
||||
return ColumnNameGenerator.ColumnName(
|
||||
column.toClickHouseCompatibleName(),
|
||||
column.lowercase(Locale.getDefault()).toClickHouseCompatibleName(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Transforms a string to be compatible with ClickHouse table and column names.
|
||||
*
|
||||
* @return The transformed string suitable for ClickHouse identifiers.
|
||||
*/
|
||||
fun String.toClickHouseCompatibleName(): String {
|
||||
// 1. Replace any character that is not a letter,
|
||||
// a digit (0-9), or an underscore (_) with a single underscore.
|
||||
var transformed = toAlphanumericAndUnderscore(this)
|
||||
|
||||
// 2. Ensure the identifier does not start with a digit.
|
||||
// If it starts with a digit, prepend an underscore.
|
||||
if (transformed.isNotEmpty() && transformed[0].isDigit()) {
|
||||
transformed = "_$transformed"
|
||||
}
|
||||
|
||||
// 3.Do not allow empty strings.
|
||||
if (transformed.isEmpty()) {
|
||||
return "default_name_${UUID.randomUUID()}" // A fallback name if the input results in an
|
||||
// empty string
|
||||
}
|
||||
|
||||
return transformed
|
||||
}
|
||||
@@ -0,0 +1,33 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.integrations.destination.clickhouse.schema

import io.airbyte.cdk.load.data.Transformations.Companion.toAlphanumericAndUnderscore
import java.util.UUID

/**
 * Transforms a string to be compatible with ClickHouse table and column names.
 *
 * @return The transformed string suitable for ClickHouse identifiers.
 */
fun String.toClickHouseCompatibleName(): String {
    // 1. Replace any character that is not a letter,
    // a digit (0-9), or an underscore (_) with a single underscore.
    var transformed = toAlphanumericAndUnderscore(this)

    // 2. Do not allow empty strings.
    if (transformed.isEmpty()) {
        return "default_name_${UUID.randomUUID()}" // A fallback name if the input results in an
        // empty string
    }

    // 3. Ensure the identifier does not start with a digit.
    // If it starts with a digit, prepend an underscore.
    if (transformed[0].isDigit()) {
        transformed = "_$transformed"
    }

    return transformed
}
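Example outputs, following the numbered steps above (the first line mirrors the existing unit test expectation; the others follow directly from the code):

```kotlin
"hello world".toClickHouseCompatibleName() // "hello_world"
"123stream".toClickHouseCompatibleName()   // "_123stream" (leading digit gets an underscore)
"".toClickHouseCompatibleName()            // "default_name_<random UUID>"
```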
|
||||
@@ -29,8 +29,7 @@ import io.airbyte.cdk.load.schema.model.StreamTableSchema
|
||||
import io.airbyte.cdk.load.schema.model.TableName
|
||||
import io.airbyte.cdk.load.table.TempTableNameGenerator
|
||||
import io.airbyte.integrations.destination.clickhouse.client.ClickhouseSqlTypes
|
||||
import io.airbyte.integrations.destination.clickhouse.client.isValidVersionColumnType
|
||||
import io.airbyte.integrations.destination.clickhouse.config.toClickHouseCompatibleName
|
||||
import io.airbyte.integrations.destination.clickhouse.client.isValidVersionColumn
|
||||
import io.airbyte.integrations.destination.clickhouse.spec.ClickhouseConfiguration
|
||||
import jakarta.inject.Singleton
|
||||
|
||||
@@ -100,7 +99,7 @@ class ClickhouseTableSchemaMapper(
|
||||
if (cursor != null) {
|
||||
// Check if the cursor column type is valid for ClickHouse ReplacingMergeTree
|
||||
val cursorColumnType = tableSchema.columnSchema.finalSchema[cursor]!!.type
|
||||
if (cursorColumnType.isValidVersionColumnType()) {
|
||||
if (isValidVersionColumn(cursor, cursorColumnType)) {
|
||||
// Cursor column is valid, use it as version column
|
||||
add(cursor) // Make cursor column non-nullable too
|
||||
}
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination.clickhouse.component
|
||||
|
||||
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures
|
||||
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures.NEGATIVE_HIGH_PRECISION_FLOAT
|
||||
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures.POSITIVE_HIGH_PRECISION_FLOAT
|
||||
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures.SMALLEST_NEGATIVE_FLOAT32
|
||||
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures.SMALLEST_NEGATIVE_FLOAT64
|
||||
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures.SMALLEST_POSITIVE_FLOAT32
|
||||
import io.airbyte.cdk.load.component.DataCoercionNumberFixtures.SMALLEST_POSITIVE_FLOAT64
|
||||
import io.airbyte.cdk.load.component.DataCoercionSuite
|
||||
import io.airbyte.cdk.load.component.TableOperationsClient
|
||||
import io.airbyte.cdk.load.component.TestTableOperationsClient
|
||||
import io.airbyte.cdk.load.component.toArgs
|
||||
import io.airbyte.cdk.load.data.AirbyteValue
|
||||
import io.airbyte.cdk.load.dataflow.transform.ValueCoercer
|
||||
import io.airbyte.cdk.load.schema.TableSchemaFactory
|
||||
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Reason
|
||||
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
|
||||
import org.junit.jupiter.params.ParameterizedTest
|
||||
import org.junit.jupiter.params.provider.MethodSource
|
||||
|
||||
@MicronautTest(environments = ["component"], resolveParameters = false)
|
||||
class ClickhouseDataCoercionTest(
|
||||
override val coercer: ValueCoercer,
|
||||
override val opsClient: TableOperationsClient,
|
||||
override val testClient: TestTableOperationsClient,
|
||||
override val schemaFactory: TableSchemaFactory,
|
||||
) : DataCoercionSuite {
|
||||
@ParameterizedTest
|
||||
// We use clickhouse's Int64 type for integers
|
||||
@MethodSource("io.airbyte.cdk.load.component.DataCoercionIntegerFixtures#int64")
|
||||
override fun `handle integer values`(
|
||||
inputValue: AirbyteValue,
|
||||
expectedValue: Any?,
|
||||
expectedChangeReason: Reason?
|
||||
) {
|
||||
super.`handle integer values`(inputValue, expectedValue, expectedChangeReason)
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource(
|
||||
"io.airbyte.integrations.destination.clickhouse.component.ClickhouseDataCoercionTest#numbers"
|
||||
)
|
||||
override fun `handle number values`(
|
||||
inputValue: AirbyteValue,
|
||||
expectedValue: Any?,
|
||||
expectedChangeReason: Reason?
|
||||
) {
|
||||
super.`handle number values`(inputValue, expectedValue, expectedChangeReason)
|
||||
}
|
||||
|
||||
companion object {
|
||||
/**
 * destination-clickhouse doesn't set a change reason when truncating high-precision numbers
 * (https://github.com/airbytehq/airbyte-internal-issues/issues/15401)
 */
|
||||
@JvmStatic
|
||||
fun numbers() =
|
||||
DataCoercionNumberFixtures.numeric38_9
|
||||
.map {
|
||||
when (it.name) {
|
||||
POSITIVE_HIGH_PRECISION_FLOAT,
|
||||
NEGATIVE_HIGH_PRECISION_FLOAT,
|
||||
SMALLEST_POSITIVE_FLOAT32,
|
||||
SMALLEST_NEGATIVE_FLOAT32,
|
||||
SMALLEST_POSITIVE_FLOAT64,
|
||||
SMALLEST_NEGATIVE_FLOAT64 -> it.copy(changeReason = null)
|
||||
else -> it
|
||||
}
|
||||
}
|
||||
.toArgs()
|
||||
}
|
||||
}
|
||||
@@ -22,7 +22,7 @@ class ClickhouseTableSchemaEvolutionTest(
|
||||
override val client: TableSchemaEvolutionClient,
|
||||
override val opsClient: TableOperationsClient,
|
||||
override val testClient: TestTableOperationsClient,
|
||||
override val schemaFactory: TableSchemaFactory
|
||||
override val schemaFactory: TableSchemaFactory,
|
||||
) : TableSchemaEvolutionSuite {
|
||||
private val allTypesTableSchema =
|
||||
TableSchema(
|
||||
|
||||
@@ -16,7 +16,7 @@ import io.airbyte.cdk.load.data.TimestampWithTimezoneValue
|
||||
import io.airbyte.cdk.load.data.TimestampWithoutTimezoneValue
|
||||
import io.airbyte.cdk.load.test.util.ExpectedRecordMapper
|
||||
import io.airbyte.cdk.load.test.util.OutputRecord
|
||||
import io.airbyte.integrations.destination.clickhouse.config.toClickHouseCompatibleName
|
||||
import io.airbyte.integrations.destination.clickhouse.schema.toClickHouseCompatibleName
|
||||
import java.math.RoundingMode
|
||||
import java.time.LocalTime
|
||||
import java.time.ZoneOffset
|
||||
|
||||
@@ -30,8 +30,8 @@ import io.airbyte.cdk.load.write.UnknownTypesBehavior
|
||||
import io.airbyte.integrations.destination.clickhouse.ClickhouseConfigUpdater
|
||||
import io.airbyte.integrations.destination.clickhouse.ClickhouseContainerHelper
|
||||
import io.airbyte.integrations.destination.clickhouse.Utils
|
||||
import io.airbyte.integrations.destination.clickhouse.config.toClickHouseCompatibleName
|
||||
import io.airbyte.integrations.destination.clickhouse.fixtures.ClickhouseExpectedRecordMapper
|
||||
import io.airbyte.integrations.destination.clickhouse.schema.toClickHouseCompatibleName
|
||||
import io.airbyte.integrations.destination.clickhouse.spec.ClickhouseConfiguration
|
||||
import io.airbyte.integrations.destination.clickhouse.spec.ClickhouseConfigurationFactory
|
||||
import io.airbyte.integrations.destination.clickhouse.spec.ClickhouseSpecificationOss
|
||||
|
||||
@@ -21,7 +21,6 @@ import io.airbyte.cdk.load.schema.model.StreamTableSchema
|
||||
import io.airbyte.cdk.load.schema.model.TableName
|
||||
import io.airbyte.cdk.load.table.ColumnNameMapping
|
||||
import io.airbyte.cdk.load.table.TempTableNameGenerator
|
||||
import io.airbyte.integrations.destination.clickhouse.config.ClickhouseFinalTableNameGenerator
|
||||
import io.mockk.coEvery
|
||||
import io.mockk.coVerify
|
||||
import io.mockk.coVerifyOrder
|
||||
@@ -39,8 +38,6 @@ class ClickhouseAirbyteClientTest {
|
||||
// Mocks
|
||||
private val client: ClickHouseClientRaw = mockk(relaxed = true)
|
||||
private val clickhouseSqlGenerator: ClickhouseSqlGenerator = mockk(relaxed = true)
|
||||
private val clickhouseFinalTableNameGenerator: ClickhouseFinalTableNameGenerator =
|
||||
mockk(relaxed = true)
|
||||
private val tempTableNameGenerator: TempTableNameGenerator = mockk(relaxed = true)
|
||||
|
||||
// Client
|
||||
@@ -105,7 +102,6 @@ class ClickhouseAirbyteClientTest {
|
||||
alterTableStatement
|
||||
coEvery { clickhouseAirbyteClient.execute(alterTableStatement) } returns
|
||||
mockk(relaxed = true)
|
||||
every { clickhouseFinalTableNameGenerator.getTableName(any()) } returns mockTableName
|
||||
|
||||
mockCHSchemaWithAirbyteColumns()
|
||||
|
||||
@@ -172,7 +168,6 @@ class ClickhouseAirbyteClientTest {
|
||||
|
||||
coEvery { clickhouseAirbyteClient.execute(any()) } returns mockk(relaxed = true)
|
||||
every { tempTableNameGenerator.generate(any()) } returns tempTableName
|
||||
every { clickhouseFinalTableNameGenerator.getTableName(any()) } returns finalTableName
|
||||
|
||||
mockCHSchemaWithAirbyteColumns()
|
||||
|
||||
@@ -226,8 +221,6 @@ class ClickhouseAirbyteClientTest {
|
||||
fun `test ensure schema matches fails if no airbyte columns`() = runTest {
|
||||
val finalTableName = TableName("fin", "al")
|
||||
|
||||
every { clickhouseFinalTableNameGenerator.getTableName(any()) } returns finalTableName
|
||||
|
||||
val columnMapping = ColumnNameMapping(mapOf())
|
||||
val stream =
|
||||
mockk<DestinationStream> {
|
||||
|
||||
@@ -2,13 +2,13 @@
|
||||
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination.clickhouse.config
|
||||
package io.airbyte.integrations.destination.clickhouse.schema
|
||||
|
||||
import java.util.UUID
|
||||
import org.junit.jupiter.api.Assertions
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
class ClickhouseNameGeneratorTest {
|
||||
class ClickhouseNamingUtilsTest {
|
||||
@Test
|
||||
fun `toClickHouseCompatibleName replaces special characters with underscores`() {
|
||||
Assertions.assertEquals("hello_world", "hello world".toClickHouseCompatibleName())
|
||||
@@ -6,7 +6,7 @@ data:
|
||||
connectorSubtype: database
|
||||
connectorType: destination
|
||||
definitionId: 25c5221d-dce2-4163-ade9-739ef790f503
|
||||
dockerImageTag: 3.0.5-rc.1
|
||||
dockerImageTag: 3.0.5
|
||||
dockerRepository: airbyte/destination-postgres
|
||||
documentationUrl: https://docs.airbyte.com/integrations/destinations/postgres
|
||||
githubIssueLabel: destination-postgres
|
||||
@@ -22,7 +22,7 @@ data:
|
||||
enabled: true
|
||||
releases:
|
||||
rolloutConfiguration:
|
||||
enableProgressiveRollout: true
|
||||
enableProgressiveRollout: false
|
||||
breakingChanges:
|
||||
3.0.0:
|
||||
message: >
|
||||
|
||||
@@ -4,12 +4,16 @@
|
||||
|
||||
package io.airbyte.integrations.destination.postgres.client
|
||||
|
||||
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
|
||||
import io.airbyte.cdk.ConfigErrorException
|
||||
import io.airbyte.cdk.load.command.DestinationStream
|
||||
import io.airbyte.cdk.load.component.ColumnChangeset
|
||||
import io.airbyte.cdk.load.component.ColumnType
|
||||
import io.airbyte.cdk.load.component.TableColumns
|
||||
import io.airbyte.cdk.load.component.TableOperationsClient
|
||||
import io.airbyte.cdk.load.component.TableSchema
|
||||
import io.airbyte.cdk.load.component.TableSchemaEvolutionClient
|
||||
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAMES
|
||||
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
|
||||
import io.airbyte.cdk.load.schema.model.TableName
|
||||
import io.airbyte.cdk.load.table.ColumnNameMapping
|
||||
@@ -26,6 +30,11 @@ import javax.sql.DataSource
|
||||
private val log = KotlinLogging.logger {}
|
||||
|
||||
@Singleton
|
||||
@SuppressFBWarnings(
|
||||
value = ["SQL_NONCONSTANT_STRING_PASSED_TO_EXECUTE"],
|
||||
justification =
|
||||
"There is little chance of SQL injection. There is also little need for statement reuse. The basic statement is more readable than the prepared statement."
|
||||
)
|
||||
class PostgresAirbyteClient(
|
||||
private val dataSource: DataSource,
|
||||
private val sqlGenerator: PostgresDirectLoadSqlGenerator,
|
||||
@@ -53,6 +62,29 @@ class PostgresAirbyteClient(
|
||||
null
|
||||
}
|
||||
|
||||
    override suspend fun namespaceExists(namespace: String): Boolean {
        return executeQuery(
            """
            SELECT EXISTS(
                SELECT 1 FROM information_schema.schemata
                WHERE schema_name = '$namespace'
            )
            """
        ) { rs -> rs.next() && rs.getBoolean(1) }
    }

    override suspend fun tableExists(table: TableName): Boolean {
        return executeQuery(
            """
            SELECT EXISTS(
                SELECT 1 FROM information_schema.tables
                WHERE table_schema = '${table.namespace}'
                AND table_name = '${table.name}'
            )
            """
        ) { rs -> rs.next() && rs.getBoolean(1) }
    }
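These existence checks let callers guard their DDL. A rough sketch of the intended call pattern, illustrative only and assuming the `createTable` signature used by the table-operations client elsewhere in this change:

```kotlin
// Hypothetical helper, not code from this change.
suspend fun ensureTargetTable(
    client: PostgresAirbyteClient,
    stream: DestinationStream,
    table: TableName,
    columnNameMapping: ColumnNameMapping,
) {
    if (!client.namespaceExists(table.namespace)) {
        client.createNamespace(table.namespace)
    }
    if (!client.tableExists(table)) {
        client.createTable(stream, table, columnNameMapping, replace = false)
    }
}
```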
|
||||
|
||||
override suspend fun createNamespace(namespace: String) {
|
||||
try {
|
||||
execute(sqlGenerator.createNamespace(namespace))
|
||||
@@ -171,14 +203,26 @@ class PostgresAirbyteClient(
|
||||
}
|
||||
|
||||
override suspend fun discoverSchema(tableName: TableName): TableSchema {
|
||||
TODO("Not yet implemented")
|
||||
val columnsInDb = getColumnsFromDbForDiscovery(tableName)
|
||||
val hasAllAirbyteColumns = columnsInDb.keys.containsAll(COLUMN_NAMES)
|
||||
|
||||
if (!hasAllAirbyteColumns) {
|
||||
val message =
|
||||
"The target table ($tableName) already exists in the destination, but does not contain Airbyte's internal columns. Airbyte can only sync to Airbyte-controlled tables. To fix this error, you must either delete the target table or add a prefix in the connection configuration in order to sync to a separate table in the destination."
|
||||
log.error { message }
|
||||
throw ConfigErrorException(message)
|
||||
}
|
||||
|
||||
// Filter out Airbyte columns
|
||||
val userColumns = columnsInDb.filterKeys { it !in COLUMN_NAMES }
|
||||
return TableSchema(userColumns)
|
||||
}
|
||||
|
||||
override fun computeSchema(
|
||||
stream: DestinationStream,
|
||||
columnNameMapping: ColumnNameMapping
|
||||
): TableSchema {
|
||||
TODO("Not yet implemented")
|
||||
return TableSchema(stream.tableSchema.columnSchema.finalSchema)
|
||||
}
|
||||
|
||||
override suspend fun applyChangeset(
|
||||
@@ -188,9 +232,73 @@ class PostgresAirbyteClient(
|
||||
expectedColumns: TableColumns,
|
||||
columnChangeset: ColumnChangeset
|
||||
) {
|
||||
TODO("Not yet implemented")
|
||||
if (
|
||||
columnChangeset.columnsToAdd.isNotEmpty() ||
|
||||
columnChangeset.columnsToDrop.isNotEmpty() ||
|
||||
columnChangeset.columnsToChange.isNotEmpty()
|
||||
) {
|
||||
log.info { "Summary of the table alterations:" }
|
||||
log.info { "Added columns: ${columnChangeset.columnsToAdd}" }
|
||||
log.info { "Deleted columns: ${columnChangeset.columnsToDrop}" }
|
||||
log.info { "Modified columns: ${columnChangeset.columnsToChange}" }
|
||||
|
||||
// Convert from TableColumns format to Column format
|
||||
val columnsToAdd =
|
||||
columnChangeset.columnsToAdd
|
||||
.map { (name, type) -> Column(name, type.type, type.nullable) }
|
||||
.toSet()
|
||||
val columnsToRemove =
|
||||
columnChangeset.columnsToDrop
|
||||
.map { (name, type) -> Column(name, type.type, type.nullable) }
|
||||
.toSet()
|
||||
val columnsToModify =
|
||||
columnChangeset.columnsToChange
|
||||
.map { (name, change) ->
|
||||
Column(name, change.newType.type, change.newType.nullable)
|
||||
}
|
||||
.toSet()
|
||||
val columnsInDb =
|
||||
(columnChangeset.columnsToRetain +
|
||||
columnChangeset.columnsToDrop +
|
||||
columnChangeset.columnsToChange.mapValues { it.value.originalType })
|
||||
.map { (name, type) -> Column(name, type.type, type.nullable) }
|
||||
.toSet()
|
||||
|
||||
execute(
|
||||
sqlGenerator.matchSchemas(
|
||||
tableName = tableName,
|
||||
columnsToAdd = columnsToAdd,
|
||||
columnsToRemove = columnsToRemove,
|
||||
columnsToModify = columnsToModify,
|
||||
columnsInDb = columnsInDb,
|
||||
recreatePrimaryKeyIndex = false,
|
||||
primaryKeyColumnNames = emptyList(),
|
||||
recreateCursorIndex = false,
|
||||
cursorColumnName = null,
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Gets columns from the database including their types for schema discovery. Unlike
 * [getColumnsFromDb], this returns all columns including Airbyte metadata columns.
 */
|
||||
private fun getColumnsFromDbForDiscovery(tableName: TableName): Map<String, ColumnType> =
|
||||
executeQuery(sqlGenerator.getTableSchema(tableName)) { rs ->
|
||||
val columnsInDb: MutableMap<String, ColumnType> = mutableMapOf()
|
||||
while (rs.next()) {
|
||||
val columnName = rs.getString(COLUMN_NAME_COLUMN)
|
||||
val dataType = rs.getString("data_type")
|
||||
// PostgreSQL's information_schema always returns 'YES' or 'NO' for is_nullable
|
||||
val isNullable = rs.getString("is_nullable") == "YES"
|
||||
|
||||
columnsInDb[columnName] = ColumnType(normalizePostgresType(dataType), isNullable)
|
||||
}
|
||||
|
||||
columnsInDb
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the primary key index matches the current stream configuration. If the primary keys
|
||||
* have changed (detected by comparing columns in the index), then this will return true,
|
||||
|
||||
@@ -531,7 +531,7 @@ class PostgresDirectLoadSqlGenerator(
|
||||
|
||||
fun getTableSchema(tableName: TableName): String =
|
||||
"""
|
||||
SELECT column_name, data_type
|
||||
SELECT column_name, data_type, is_nullable
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = '${tableName.namespace}'
|
||||
AND table_name = '${tableName.name}';
|
||||
|
||||
@@ -49,6 +49,7 @@ class PostgresWriter(
|
||||
override fun createStreamLoader(stream: DestinationStream): StreamLoader {
|
||||
val initialStatus = initialStatuses[stream]!!
|
||||
val realTableName = stream.tableSchema.tableNames.finalTableName!!
|
||||
|
||||
val tempTableName = tempTableNameGenerator.generate(realTableName)
|
||||
val columnNameMapping =
|
||||
ColumnNameMapping(stream.tableSchema.columnSchema.inputToFinalColumnNames)
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination.postgres.component
|
||||
|
||||
import io.airbyte.cdk.load.util.Jsons
|
||||
import io.airbyte.integrations.destination.postgres.PostgresConfigUpdater
|
||||
import io.airbyte.integrations.destination.postgres.PostgresContainerHelper
|
||||
import io.airbyte.integrations.destination.postgres.spec.PostgresConfiguration
|
||||
import io.airbyte.integrations.destination.postgres.spec.PostgresConfigurationFactory
|
||||
import io.airbyte.integrations.destination.postgres.spec.PostgresSpecificationOss
|
||||
import io.micronaut.context.annotation.Factory
|
||||
import io.micronaut.context.annotation.Primary
|
||||
import io.micronaut.context.annotation.Requires
|
||||
import jakarta.inject.Singleton
|
||||
|
||||
@Requires(env = ["component"])
|
||||
@Factory
|
||||
class PostgresComponentTestConfigFactory {
|
||||
@Singleton
|
||||
@Primary
|
||||
fun config(): PostgresConfiguration {
|
||||
// Start the postgres container
|
||||
PostgresContainerHelper.start()
|
||||
|
||||
// Create a minimal config JSON and update it with container details
|
||||
val configJson =
|
||||
"""
|
||||
{
|
||||
"host": "replace_me_host",
|
||||
"port": "replace_me_port",
|
||||
"database": "replace_me_database",
|
||||
"schema": "public",
|
||||
"username": "replace_me_username",
|
||||
"password": "replace_me_password",
|
||||
"ssl": false
|
||||
}
|
||||
"""
|
||||
|
||||
val updatedConfig = PostgresConfigUpdater().update(configJson)
|
||||
val spec = Jsons.readValue(updatedConfig, PostgresSpecificationOss::class.java)
|
||||
return PostgresConfigurationFactory().makeWithoutExceptionHandling(spec)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination.postgres.component
|
||||
|
||||
import io.airbyte.cdk.load.component.ColumnType
|
||||
import io.airbyte.cdk.load.component.TableOperationsFixtures
|
||||
import io.airbyte.cdk.load.component.TableSchema
|
||||
|
||||
object PostgresComponentTestFixtures {
|
||||
// PostgreSQL uses lowercase column names by default (no transformation needed)
|
||||
val testMapping = TableOperationsFixtures.TEST_MAPPING
|
||||
val idAndTestMapping = TableOperationsFixtures.ID_AND_TEST_MAPPING
|
||||
val idTestWithCdcMapping = TableOperationsFixtures.ID_TEST_WITH_CDC_MAPPING
|
||||
|
||||
val allTypesTableSchema =
|
||||
TableSchema(
|
||||
mapOf(
|
||||
"string" to ColumnType("varchar", true),
|
||||
"boolean" to ColumnType("boolean", true),
|
||||
"integer" to ColumnType("bigint", true),
|
||||
"number" to ColumnType("decimal", true),
|
||||
"date" to ColumnType("date", true),
|
||||
"timestamp_tz" to ColumnType("timestamp with time zone", true),
|
||||
"timestamp_ntz" to ColumnType("timestamp", true),
|
||||
"time_tz" to ColumnType("time with time zone", true),
|
||||
"time_ntz" to ColumnType("time", true),
|
||||
"array" to ColumnType("jsonb", true),
|
||||
"object" to ColumnType("jsonb", true),
|
||||
"unknown" to ColumnType("jsonb", true),
|
||||
)
|
||||
)
|
||||
|
||||
val allTypesColumnNameMapping = TableOperationsFixtures.ALL_TYPES_MAPPING
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination.postgres.component
|
||||
|
||||
import io.airbyte.cdk.load.component.TableOperationsFixtures
|
||||
import io.airbyte.cdk.load.component.TableOperationsSuite
|
||||
import io.airbyte.cdk.load.schema.TableSchemaFactory
|
||||
import io.airbyte.integrations.destination.postgres.client.PostgresAirbyteClient
|
||||
import io.airbyte.integrations.destination.postgres.component.PostgresComponentTestFixtures.idTestWithCdcMapping
|
||||
import io.airbyte.integrations.destination.postgres.component.PostgresComponentTestFixtures.testMapping
|
||||
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
|
||||
import jakarta.inject.Inject
|
||||
import org.junit.jupiter.api.Disabled
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
@MicronautTest(environments = ["component"])
|
||||
class PostgresTableOperationsTest(
|
||||
override val client: PostgresAirbyteClient,
|
||||
override val testClient: PostgresTestTableOperationsClient,
|
||||
) : TableOperationsSuite {
|
||||
|
||||
@Inject override lateinit var schemaFactory: TableSchemaFactory
|
||||
|
||||
@Test
|
||||
override fun `connect to database`() {
|
||||
super.`connect to database`()
|
||||
}
|
||||
|
||||
@Test
|
||||
override fun `create and drop namespaces`() {
|
||||
super.`create and drop namespaces`()
|
||||
}
|
||||
|
||||
@Test
|
||||
override fun `create and drop tables`() {
|
||||
super.`create and drop tables`()
|
||||
}
|
||||
|
||||
@Test
|
||||
override fun `insert records`() {
|
||||
super.`insert records`(
|
||||
inputRecords = TableOperationsFixtures.SINGLE_TEST_RECORD_INPUT,
|
||||
expectedRecords = TableOperationsFixtures.SINGLE_TEST_RECORD_EXPECTED,
|
||||
columnNameMapping = testMapping,
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
override fun `count table rows`() {
|
||||
super.`count table rows`(columnNameMapping = testMapping)
|
||||
}
|
||||
|
||||
@Test
|
||||
override fun `overwrite tables`() {
|
||||
super.`overwrite tables`(
|
||||
sourceInputRecords = TableOperationsFixtures.OVERWRITE_SOURCE_RECORDS,
|
||||
targetInputRecords = TableOperationsFixtures.OVERWRITE_TARGET_RECORDS,
|
||||
expectedRecords = TableOperationsFixtures.OVERWRITE_EXPECTED_RECORDS,
|
||||
columnNameMapping = testMapping,
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
override fun `copy tables`() {
|
||||
super.`copy tables`(
|
||||
sourceInputRecords = TableOperationsFixtures.OVERWRITE_SOURCE_RECORDS,
|
||||
targetInputRecords = TableOperationsFixtures.OVERWRITE_TARGET_RECORDS,
|
||||
expectedRecords = TableOperationsFixtures.COPY_EXPECTED_RECORDS,
|
||||
columnNameMapping = testMapping,
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
override fun `get generation id`() {
|
||||
super.`get generation id`(columnNameMapping = testMapping)
|
||||
}
|
||||
|
||||
// TODO: Re-enable when CDK TableOperationsSuite is fixed to use ID_AND_TEST_SCHEMA for target
|
||||
// table instead of TEST_INTEGER_SCHEMA (the Dedupe mode requires the id column as primary key)
|
||||
@Disabled("CDK TableOperationsSuite bug: target table schema missing 'id' column for Dedupe")
|
||||
@Test
|
||||
override fun `upsert tables`() {
|
||||
super.`upsert tables`(
|
||||
sourceInputRecords = TableOperationsFixtures.UPSERT_SOURCE_RECORDS,
|
||||
targetInputRecords = TableOperationsFixtures.UPSERT_TARGET_RECORDS,
|
||||
expectedRecords = TableOperationsFixtures.UPSERT_EXPECTED_RECORDS,
|
||||
columnNameMapping = idTestWithCdcMapping,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination.postgres.component
|
||||
|
||||
import io.airbyte.cdk.load.command.ImportType
|
||||
import io.airbyte.cdk.load.component.TableSchemaEvolutionFixtures
|
||||
import io.airbyte.cdk.load.component.TableSchemaEvolutionSuite
|
||||
import io.airbyte.cdk.load.schema.TableSchemaFactory
|
||||
import io.airbyte.integrations.destination.postgres.client.PostgresAirbyteClient
|
||||
import io.airbyte.integrations.destination.postgres.component.PostgresComponentTestFixtures.allTypesColumnNameMapping
|
||||
import io.airbyte.integrations.destination.postgres.component.PostgresComponentTestFixtures.allTypesTableSchema
|
||||
import io.airbyte.integrations.destination.postgres.component.PostgresComponentTestFixtures.idAndTestMapping
|
||||
import io.airbyte.integrations.destination.postgres.component.PostgresComponentTestFixtures.testMapping
|
||||
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
@MicronautTest(environments = ["component"], resolveParameters = false)
|
||||
class PostgresTableSchemaEvolutionTest(
|
||||
override val client: PostgresAirbyteClient,
|
||||
override val opsClient: PostgresAirbyteClient,
|
||||
override val testClient: PostgresTestTableOperationsClient,
|
||||
override val schemaFactory: TableSchemaFactory,
|
||||
) : TableSchemaEvolutionSuite {
|
||||
|
||||
@Test
|
||||
fun `discover recognizes all data types`() {
|
||||
super.`discover recognizes all data types`(allTypesTableSchema, allTypesColumnNameMapping)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `computeSchema handles all data types`() {
|
||||
super.`computeSchema handles all data types`(allTypesTableSchema, allTypesColumnNameMapping)
|
||||
}
|
||||
|
||||
@Test
|
||||
override fun `noop diff`() {
|
||||
super.`noop diff`(testMapping)
|
||||
}
|
||||
|
||||
@Test
|
||||
override fun `changeset is correct when adding a column`() {
|
||||
super.`changeset is correct when adding a column`(testMapping, idAndTestMapping)
|
||||
}
|
||||
|
||||
@Test
|
||||
override fun `changeset is correct when dropping a column`() {
|
||||
super.`changeset is correct when dropping a column`(idAndTestMapping, testMapping)
|
||||
}
|
||||
|
||||
@Test
|
||||
override fun `changeset is correct when changing a column's type`() {
|
||||
super.`changeset is correct when changing a column's type`(testMapping)
|
||||
}
|
||||
|
||||
@Test
|
||||
override fun `apply changeset - handle sync mode append`() {
|
||||
super.`apply changeset - handle sync mode append`()
|
||||
}
|
||||
|
||||
@Test
|
||||
override fun `apply changeset - handle changing sync mode from append to dedup`() {
|
||||
super.`apply changeset - handle changing sync mode from append to dedup`()
|
||||
}
|
||||
|
||||
@Test
|
||||
override fun `apply changeset - handle changing sync mode from dedup to append`() {
|
||||
super.`apply changeset - handle changing sync mode from dedup to append`()
|
||||
}
|
||||
|
||||
@Test
|
||||
override fun `apply changeset - handle sync mode dedup`() {
|
||||
super.`apply changeset - handle sync mode dedup`()
|
||||
}
|
||||
|
||||
override fun `apply changeset`(
|
||||
initialStreamImportType: ImportType,
|
||||
modifiedStreamImportType: ImportType,
|
||||
) {
|
||||
super.`apply changeset`(
|
||||
initialColumnNameMapping =
|
||||
TableSchemaEvolutionFixtures.APPLY_CHANGESET_INITIAL_COLUMN_MAPPING,
|
||||
modifiedColumnNameMapping =
|
||||
TableSchemaEvolutionFixtures.APPLY_CHANGESET_MODIFIED_COLUMN_MAPPING,
|
||||
TableSchemaEvolutionFixtures.APPLY_CHANGESET_EXPECTED_EXTRACTED_AT,
|
||||
initialStreamImportType,
|
||||
modifiedStreamImportType,
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
override fun `change from string type to unknown type`() {
|
||||
super.`change from string type to unknown type`(
|
||||
idAndTestMapping,
|
||||
idAndTestMapping,
|
||||
TableSchemaEvolutionFixtures.STRING_TO_UNKNOWN_TYPE_INPUT_RECORDS,
|
||||
TableSchemaEvolutionFixtures.STRING_TO_UNKNOWN_TYPE_EXPECTED_RECORDS,
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
override fun `change from unknown type to string type`() {
|
||||
super.`change from unknown type to string type`(
|
||||
idAndTestMapping,
|
||||
idAndTestMapping,
|
||||
TableSchemaEvolutionFixtures.UNKNOWN_TO_STRING_TYPE_INPUT_RECORDS,
|
||||
TableSchemaEvolutionFixtures.UNKNOWN_TO_STRING_TYPE_EXPECTED_RECORDS,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,257 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination.postgres.component
|
||||
|
||||
import io.airbyte.cdk.load.component.TestTableOperationsClient
|
||||
import io.airbyte.cdk.load.data.AirbyteValue
|
||||
import io.airbyte.cdk.load.schema.model.TableName
|
||||
import io.airbyte.cdk.load.util.Jsons
|
||||
import io.airbyte.integrations.destination.postgres.client.PostgresAirbyteClient
|
||||
import io.micronaut.context.annotation.Requires
|
||||
import jakarta.inject.Singleton
|
||||
import java.time.OffsetDateTime
|
||||
import java.time.ZoneOffset
|
||||
import java.time.format.DateTimeFormatter
|
||||
import javax.sql.DataSource
|
||||
|
||||
@Requires(env = ["component"])
|
||||
@Singleton
|
||||
class PostgresTestTableOperationsClient(
|
||||
private val dataSource: DataSource,
|
||||
private val client: PostgresAirbyteClient,
|
||||
) : TestTableOperationsClient {
|
||||
override suspend fun ping() {
|
||||
dataSource.connection.use { connection ->
|
||||
connection.createStatement().use { statement -> statement.executeQuery("SELECT 1") }
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun dropNamespace(namespace: String) {
|
||||
dataSource.connection.use { connection ->
|
||||
connection.createStatement().use { statement ->
|
||||
statement.execute("DROP SCHEMA IF EXISTS \"$namespace\" CASCADE")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun insertRecords(table: TableName, records: List<Map<String, AirbyteValue>>) {
|
||||
if (records.isEmpty()) return
|
||||
|
||||
// Get column types from database to handle jsonb columns properly
|
||||
val columnTypes = getColumnTypes(table)
|
||||
|
||||
// Get all unique columns from ALL records to handle sparse data (e.g., CDC deletion column)
|
||||
val columns = records.flatMap { it.keys }.distinct().toList()
|
||||
val columnNames = columns.joinToString(", ") { "\"$it\"" }
|
||||
val placeholders = columns.indices.joinToString(", ") { "?" }
|
||||
|
||||
val sql =
|
||||
"""
|
||||
INSERT INTO "${table.namespace}"."${table.name}" ($columnNames)
|
||||
VALUES ($placeholders)
|
||||
"""
|
||||
|
||||
dataSource.connection.use { connection ->
|
||||
connection.prepareStatement(sql).use { statement ->
|
||||
for (record in records) {
|
||||
columns.forEachIndexed { index, column ->
|
||||
val value = record[column]
|
||||
val columnType = columnTypes[column]
|
||||
setParameterValue(statement, index + 1, value, columnType)
|
||||
}
|
||||
statement.addBatch()
|
||||
}
|
||||
statement.executeBatch()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun getColumnTypes(table: TableName): Map<String, String> {
|
||||
val columnTypes = mutableMapOf<String, String>()
|
||||
dataSource.connection.use { connection ->
|
||||
connection.createStatement().use { statement ->
|
||||
statement
|
||||
.executeQuery(
|
||||
"""
|
||||
SELECT column_name, data_type
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = '${table.namespace}'
|
||||
AND table_name = '${table.name}'
|
||||
"""
|
||||
)
|
||||
.use { resultSet ->
|
||||
while (resultSet.next()) {
|
||||
columnTypes[resultSet.getString("column_name")] =
|
||||
resultSet.getString("data_type")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return columnTypes
|
||||
}
|
||||
|
||||
private fun setParameterValue(
|
||||
statement: java.sql.PreparedStatement,
|
||||
index: Int,
|
||||
value: AirbyteValue?,
|
||||
columnType: String?
|
||||
) {
|
||||
// If column is jsonb, serialize any value as JSON
|
||||
if (columnType == "jsonb") {
|
||||
if (value == null || value is io.airbyte.cdk.load.data.NullValue) {
|
||||
statement.setNull(index, java.sql.Types.OTHER)
|
||||
} else {
|
||||
val pgObject = org.postgresql.util.PGobject()
|
||||
pgObject.type = "jsonb"
|
||||
pgObject.value = serializeToJson(value)
|
||||
statement.setObject(index, pgObject)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
when (value) {
|
||||
null,
|
||||
is io.airbyte.cdk.load.data.NullValue -> statement.setNull(index, java.sql.Types.NULL)
|
||||
is io.airbyte.cdk.load.data.StringValue -> statement.setString(index, value.value)
|
||||
is io.airbyte.cdk.load.data.IntegerValue ->
|
||||
statement.setLong(index, value.value.toLong())
|
||||
is io.airbyte.cdk.load.data.NumberValue -> statement.setBigDecimal(index, value.value)
|
||||
is io.airbyte.cdk.load.data.BooleanValue -> statement.setBoolean(index, value.value)
|
||||
is io.airbyte.cdk.load.data.TimestampWithTimezoneValue -> {
|
||||
val offsetDateTime = OffsetDateTime.parse(value.value.toString())
|
||||
statement.setObject(index, offsetDateTime)
|
||||
}
|
||||
is io.airbyte.cdk.load.data.TimestampWithoutTimezoneValue -> {
|
||||
val localDateTime = java.time.LocalDateTime.parse(value.value.toString())
|
||||
statement.setObject(index, localDateTime)
|
||||
}
|
||||
is io.airbyte.cdk.load.data.DateValue -> {
|
||||
val localDate = java.time.LocalDate.parse(value.value.toString())
|
||||
statement.setObject(index, localDate)
|
||||
}
|
||||
is io.airbyte.cdk.load.data.TimeWithTimezoneValue -> {
|
||||
statement.setString(index, value.value.toString())
|
||||
}
|
||||
is io.airbyte.cdk.load.data.TimeWithoutTimezoneValue -> {
|
||||
val localTime = java.time.LocalTime.parse(value.value.toString())
|
||||
statement.setObject(index, localTime)
|
||||
}
|
||||
is io.airbyte.cdk.load.data.ObjectValue -> {
|
||||
val pgObject = org.postgresql.util.PGobject()
|
||||
pgObject.type = "jsonb"
|
||||
pgObject.value = Jsons.writeValueAsString(value.values)
|
||||
statement.setObject(index, pgObject)
|
||||
}
|
||||
is io.airbyte.cdk.load.data.ArrayValue -> {
|
||||
val pgObject = org.postgresql.util.PGobject()
|
||||
pgObject.type = "jsonb"
|
||||
pgObject.value = Jsons.writeValueAsString(value.values)
|
||||
statement.setObject(index, pgObject)
|
||||
}
|
||||
else -> {
|
||||
// For unknown types, try to serialize as string
|
||||
statement.setString(index, value.toString())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun serializeToJson(value: AirbyteValue): String {
|
||||
return when (value) {
|
||||
is io.airbyte.cdk.load.data.StringValue -> Jsons.writeValueAsString(value.value)
|
||||
is io.airbyte.cdk.load.data.IntegerValue -> value.value.toString()
|
||||
is io.airbyte.cdk.load.data.NumberValue -> value.value.toString()
|
||||
is io.airbyte.cdk.load.data.BooleanValue -> value.value.toString()
|
||||
is io.airbyte.cdk.load.data.ObjectValue -> Jsons.writeValueAsString(value.values)
|
||||
is io.airbyte.cdk.load.data.ArrayValue -> Jsons.writeValueAsString(value.values)
|
||||
is io.airbyte.cdk.load.data.NullValue -> "null"
|
||||
else -> Jsons.writeValueAsString(value.toString())
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun readTable(table: TableName): List<Map<String, Any>> {
|
||||
dataSource.connection.use { connection ->
|
||||
connection.createStatement().use { statement ->
|
||||
statement
|
||||
.executeQuery("""SELECT * FROM "${table.namespace}"."${table.name}"""")
|
||||
.use { resultSet ->
|
||||
val metaData = resultSet.metaData
|
||||
val columnCount = metaData.columnCount
|
||||
val result = mutableListOf<Map<String, Any>>()
|
||||
|
||||
while (resultSet.next()) {
|
||||
val row = mutableMapOf<String, Any>()
|
||||
for (i in 1..columnCount) {
|
||||
val columnName = metaData.getColumnName(i)
|
||||
val columnType = metaData.getColumnTypeName(i)
|
||||
when (columnType.lowercase()) {
|
||||
"timestamptz" -> {
|
||||
val value =
|
||||
resultSet.getObject(i, OffsetDateTime::class.java)
|
||||
if (value != null) {
|
||||
val formattedTimestamp =
|
||||
DateTimeFormatter.ISO_OFFSET_DATE_TIME.format(
|
||||
value.withOffsetSameInstant(ZoneOffset.UTC)
|
||||
)
|
||||
row[columnName] = formattedTimestamp
|
||||
}
|
||||
}
|
||||
"timestamp" -> {
|
||||
val value = resultSet.getTimestamp(i)
|
||||
if (value != null) {
|
||||
val localDateTime = value.toLocalDateTime()
|
||||
row[columnName] =
|
||||
DateTimeFormatter.ISO_LOCAL_DATE_TIME.format(
|
||||
localDateTime
|
||||
)
|
||||
}
|
||||
}
|
||||
"jsonb",
|
||||
"json" -> {
|
||||
val stringValue: String? = resultSet.getString(i)
|
||||
if (stringValue != null) {
|
||||
val parsedValue =
|
||||
Jsons.readValue(stringValue, Any::class.java)
|
||||
val actualValue =
|
||||
when (parsedValue) {
|
||||
is Int -> parsedValue.toLong()
|
||||
else -> parsedValue
|
||||
}
|
||||
row[columnName] = actualValue
|
||||
}
|
||||
}
|
||||
else -> {
|
||||
val value = resultSet.getObject(i)
|
||||
if (value != null) {
|
||||
// For varchar columns that may contain JSON (from
|
||||
// schema evolution),
|
||||
// normalize the JSON to compact format for comparison
|
||||
if (
|
||||
value is String &&
|
||||
(value.startsWith("{") || value.startsWith("["))
|
||||
) {
|
||||
try {
|
||||
val parsed =
|
||||
Jsons.readValue(value, Any::class.java)
|
||||
row[columnName] =
|
||||
Jsons.writeValueAsString(parsed)
|
||||
} catch (_: Exception) {
|
||||
row[columnName] = value
|
||||
}
|
||||
} else {
|
||||
row[columnName] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result.add(row)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -267,7 +267,7 @@ class PostgresRawDataDumper(
                .lowercase()
                .toPostgresCompatibleName()

        val fullyQualifiedTableName = "$rawNamespace.$rawName"
        val fullyQualifiedTableName = "\"$rawNamespace\".\"$rawName\""

        // Check if table exists first
        val tableExistsQuery =
@@ -302,6 +302,26 @@ class PostgresRawDataDumper(
                false
            }

        // Build the column name mapping from original names to transformed names
        // We use the stream schema to get the original field names, then transform them
        // using the postgres name transformation logic
        val finalToInputColumnNames = mutableMapOf<String, String>()
        if (stream.schema is ObjectType) {
            val objectSchema = stream.schema as ObjectType
            for (fieldName in objectSchema.properties.keys) {
                val transformedName = fieldName.toPostgresCompatibleName()
                // Map transformed name back to original name
                finalToInputColumnNames[transformedName] = fieldName
            }
        }
        // Also check if inputToFinalColumnNames mapping is available
        val inputToFinalColumnNames =
            stream.tableSchema.columnSchema.inputToFinalColumnNames
        // Add entries from the existing mapping (in case it was populated)
        for ((input, final) in inputToFinalColumnNames) {
            finalToInputColumnNames[final] = input
        }

        while (resultSet.next()) {
            val rawData =
                if (hasDataColumn) {
@@ -313,8 +333,22 @@ class PostgresRawDataDumper(
                        else -> dataObject?.toString() ?: "{}"
                    }

                    // Parse JSON to AirbyteValue, then coerce it to match the schema
                    dataJson?.deserializeToNode()?.toAirbyteValue() ?: NullValue
                    // Parse JSON to AirbyteValue, then map column names back to originals
                    val parsedValue =
                        dataJson?.deserializeToNode()?.toAirbyteValue() ?: NullValue
                    // If the parsed value is an ObjectValue, map the column names back
                    if (parsedValue is ObjectValue) {
                        val mappedProperties = linkedMapOf<String, AirbyteValue>()
                        for ((key, value) in parsedValue.values) {
                            // Map final column name back to input column name if mapping exists
                            val originalKey = finalToInputColumnNames[key] ?: key
                            mappedProperties[originalKey] = value
                        }
                        ObjectValue(mappedProperties)
                    } else {
                        parsedValue
                    }
                } else {
                    // Typed table mode: read from individual columns and reconstruct the object
@@ -333,10 +367,19 @@ class PostgresRawDataDumper(

                    for ((fieldName, fieldType) in objectSchema.properties) {
                        try {
                            // Map input field name to the transformed final column name.
                            // First check the inputToFinalColumnNames mapping, then fall
                            // back to applying postgres transformation directly
                            val transformedColumnName =
                                inputToFinalColumnNames[fieldName]
                                    ?: fieldName.toPostgresCompatibleName()

                            // Try to find the actual column name (case-insensitive lookup)
                            val actualColumnName =
                                columnMap[fieldName.lowercase()] ?: fieldName
                                columnMap[transformedColumnName.lowercase()]
                                    ?: transformedColumnName
                            val columnValue = resultSet.getObject(actualColumnName)
                            properties[fieldName] =
                                when (columnValue) {
@@ -1,3 +1,3 @@
testExecutionConcurrency=-1
cdkVersion=0.1.82
cdkVersion=0.1.91
JunitMethodExecutionTimeout=10m
@@ -6,7 +6,7 @@ data:
  connectorSubtype: database
  connectorType: destination
  definitionId: 424892c4-daac-4491-b35d-c6688ba547ba
  dockerImageTag: 4.0.31
  dockerImageTag: 4.0.32-rc.1
  dockerRepository: airbyte/destination-snowflake
  documentationUrl: https://docs.airbyte.com/integrations/destinations/snowflake
  githubIssueLabel: destination-snowflake
@@ -31,6 +31,8 @@ data:
      enabled: true
  releaseStage: generally_available
  releases:
    rolloutConfiguration:
      enableProgressiveRollout: true
    breakingChanges:
      2.0.0:
        message: Remove GCS/S3 loading method support.
@@ -11,15 +11,18 @@ import io.airbyte.cdk.load.check.CheckOperationV2
|
||||
import io.airbyte.cdk.load.check.DestinationCheckerV2
|
||||
import io.airbyte.cdk.load.config.DataChannelMedium
|
||||
import io.airbyte.cdk.load.dataflow.config.AggregatePublishingConfig
|
||||
import io.airbyte.cdk.load.orchestration.db.DefaultTempTableNameGenerator
|
||||
import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
|
||||
import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator
|
||||
import io.airbyte.cdk.load.table.TempTableNameGenerator
|
||||
import io.airbyte.cdk.output.OutputConsumer
|
||||
import io.airbyte.integrations.destination.snowflake.cdk.SnowflakeMigratingConfigurationSpecificationSupplier
|
||||
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
|
||||
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
|
||||
import io.airbyte.integrations.destination.snowflake.spec.KeyPairAuthConfiguration
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfigurationFactory
|
||||
import io.airbyte.integrations.destination.snowflake.spec.UsernamePasswordAuthConfiguration
|
||||
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRawRecordFormatter
|
||||
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
|
||||
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter
|
||||
import io.micronaut.context.annotation.Factory
|
||||
import io.micronaut.context.annotation.Primary
|
||||
import io.micronaut.context.annotation.Requires
|
||||
@@ -204,6 +207,17 @@ class SnowflakeBeanFactory {
|
||||
outputConsumer: OutputConsumer,
|
||||
) = CheckOperationV2(destinationChecker, outputConsumer)
|
||||
|
||||
@Singleton
|
||||
fun snowflakeRecordFormatter(
|
||||
snowflakeConfiguration: SnowflakeConfiguration
|
||||
): SnowflakeRecordFormatter {
|
||||
return if (snowflakeConfiguration.legacyRawTablesOnly) {
|
||||
SnowflakeRawRecordFormatter()
|
||||
} else {
|
||||
SnowflakeSchemaRecordFormatter()
|
||||
}
|
||||
}
|
||||
|
||||
@Singleton
|
||||
fun aggregatePublishingConfig(dataChannelMedium: DataChannelMedium): AggregatePublishingConfig {
|
||||
// NOT speed mode
|
||||
|
||||
@@ -13,13 +13,17 @@ import io.airbyte.cdk.load.data.FieldType
|
||||
import io.airbyte.cdk.load.data.ObjectType
|
||||
import io.airbyte.cdk.load.data.StringType
|
||||
import io.airbyte.cdk.load.message.Meta
|
||||
import io.airbyte.cdk.load.schema.model.ColumnSchema
|
||||
import io.airbyte.cdk.load.schema.model.StreamTableSchema
|
||||
import io.airbyte.cdk.load.schema.model.TableName
|
||||
import io.airbyte.cdk.load.schema.model.TableNames
|
||||
import io.airbyte.cdk.load.table.ColumnNameMapping
|
||||
import io.airbyte.cdk.load.table.TableName
|
||||
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
|
||||
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
|
||||
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
|
||||
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
|
||||
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
|
||||
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter
|
||||
import jakarta.inject.Singleton
|
||||
import java.time.OffsetDateTime
|
||||
import java.util.UUID
|
||||
@@ -31,7 +35,7 @@ internal const val CHECK_COLUMN_NAME = "test_key"
|
||||
class SnowflakeChecker(
|
||||
private val snowflakeAirbyteClient: SnowflakeAirbyteClient,
|
||||
private val snowflakeConfiguration: SnowflakeConfiguration,
|
||||
private val snowflakeColumnUtils: SnowflakeColumnUtils,
|
||||
private val columnManager: SnowflakeColumnManager,
|
||||
) : DestinationCheckerV2 {
|
||||
|
||||
override fun check() {
|
||||
@@ -46,11 +50,40 @@ class SnowflakeChecker(
|
||||
Meta.AirbyteMetaFields.GENERATION_ID.fieldName to AirbyteValue.from(0),
|
||||
CHECK_COLUMN_NAME.toSnowflakeCompatibleName() to AirbyteValue.from("test-value")
|
||||
)
|
||||
val outputSchema = snowflakeConfiguration.schema.toSnowflakeCompatibleName()
|
||||
val outputSchema =
|
||||
if (snowflakeConfiguration.legacyRawTablesOnly) {
|
||||
snowflakeConfiguration.schema
|
||||
} else {
|
||||
snowflakeConfiguration.schema.toSnowflakeCompatibleName()
|
||||
}
|
||||
val tableName =
|
||||
"_airbyte_connection_test_${
|
||||
UUID.randomUUID().toString().replace("-".toRegex(), "")}".toSnowflakeCompatibleName()
|
||||
val qualifiedTableName = TableName(namespace = outputSchema, name = tableName)
|
||||
val tableSchema =
|
||||
StreamTableSchema(
|
||||
tableNames =
|
||||
TableNames(
|
||||
finalTableName = qualifiedTableName,
|
||||
tempTableName = qualifiedTableName
|
||||
),
|
||||
columnSchema =
|
||||
ColumnSchema(
|
||||
inputToFinalColumnNames =
|
||||
mapOf(
|
||||
CHECK_COLUMN_NAME to CHECK_COLUMN_NAME.toSnowflakeCompatibleName()
|
||||
),
|
||||
finalSchema =
|
||||
mapOf(
|
||||
CHECK_COLUMN_NAME.toSnowflakeCompatibleName() to
|
||||
io.airbyte.cdk.load.component.ColumnType("VARCHAR", false)
|
||||
),
|
||||
inputSchema =
|
||||
mapOf(CHECK_COLUMN_NAME to FieldType(StringType, nullable = false))
|
||||
),
|
||||
importType = Append
|
||||
)
|
||||
|
||||
val destinationStream =
|
||||
DestinationStream(
|
||||
unmappedNamespace = outputSchema,
|
||||
@@ -63,7 +96,8 @@ class SnowflakeChecker(
|
||||
generationId = 0L,
|
||||
minimumGenerationId = 0L,
|
||||
syncId = 0L,
|
||||
namespaceMapper = NamespaceMapper()
|
||||
namespaceMapper = NamespaceMapper(),
|
||||
tableSchema = tableSchema
|
||||
)
|
||||
runBlocking {
|
||||
try {
|
||||
@@ -75,14 +109,14 @@ class SnowflakeChecker(
|
||||
replace = true,
|
||||
)
|
||||
|
||||
val columns = snowflakeAirbyteClient.describeTable(qualifiedTableName)
|
||||
val snowflakeInsertBuffer =
|
||||
SnowflakeInsertBuffer(
|
||||
tableName = qualifiedTableName,
|
||||
columns = columns,
|
||||
snowflakeClient = snowflakeAirbyteClient,
|
||||
snowflakeConfiguration = snowflakeConfiguration,
|
||||
snowflakeColumnUtils = snowflakeColumnUtils,
|
||||
columnSchema = tableSchema.columnSchema,
|
||||
columnManager = columnManager,
|
||||
snowflakeRecordFormatter = SnowflakeSchemaRecordFormatter(),
|
||||
)
|
||||
|
||||
snowflakeInsertBuffer.accumulate(data)
|
||||
|
||||
@@ -13,18 +13,16 @@ import io.airbyte.cdk.load.component.TableColumns
|
||||
import io.airbyte.cdk.load.component.TableOperationsClient
|
||||
import io.airbyte.cdk.load.component.TableSchema
|
||||
import io.airbyte.cdk.load.component.TableSchemaEvolutionClient
|
||||
import io.airbyte.cdk.load.schema.model.TableName
|
||||
import io.airbyte.cdk.load.table.ColumnNameMapping
|
||||
import io.airbyte.cdk.load.table.TableName
|
||||
import io.airbyte.cdk.load.util.deserializeToNode
|
||||
import io.airbyte.integrations.destination.snowflake.db.ColumnDefinition
|
||||
import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier
|
||||
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
|
||||
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
|
||||
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
import io.airbyte.integrations.destination.snowflake.sql.COUNT_TOTAL_ALIAS
|
||||
import io.airbyte.integrations.destination.snowflake.sql.NOT_NULL
|
||||
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
|
||||
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
|
||||
import io.airbyte.integrations.destination.snowflake.sql.andLog
|
||||
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
|
||||
import io.github.oshai.kotlinlogging.KotlinLogging
|
||||
import jakarta.inject.Singleton
|
||||
import java.sql.ResultSet
|
||||
@@ -41,13 +39,10 @@ private val log = KotlinLogging.logger {}
|
||||
class SnowflakeAirbyteClient(
|
||||
private val dataSource: DataSource,
|
||||
private val sqlGenerator: SnowflakeDirectLoadSqlGenerator,
|
||||
private val snowflakeColumnUtils: SnowflakeColumnUtils,
|
||||
private val snowflakeConfiguration: SnowflakeConfiguration,
|
||||
private val columnManager: SnowflakeColumnManager,
|
||||
) : TableOperationsClient, TableSchemaEvolutionClient {
|
||||
|
||||
private val airbyteColumnNames =
|
||||
snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet()
|
||||
|
||||
override suspend fun countTable(tableName: TableName): Long? =
|
||||
try {
|
||||
dataSource.connection.use { connection ->
|
||||
@@ -126,7 +121,7 @@ class SnowflakeAirbyteClient(
|
||||
columnNameMapping: ColumnNameMapping,
|
||||
replace: Boolean
|
||||
) {
|
||||
execute(sqlGenerator.createTable(stream, tableName, columnNameMapping, replace))
|
||||
execute(sqlGenerator.createTable(tableName, stream.tableSchema, replace))
|
||||
execute(sqlGenerator.createSnowflakeStage(tableName))
|
||||
}
|
||||
|
||||
@@ -163,7 +158,15 @@ class SnowflakeAirbyteClient(
|
||||
sourceTableName: TableName,
|
||||
targetTableName: TableName
|
||||
) {
|
||||
execute(sqlGenerator.copyTable(columnNameMapping, sourceTableName, targetTableName))
|
||||
// Get all column names from the mapping (both meta columns and user columns)
|
||||
val columnNames = buildSet {
|
||||
// Add Airbyte meta columns (using uppercase constants)
|
||||
addAll(columnManager.getMetaColumnNames())
|
||||
// Add user columns from mapping
|
||||
addAll(columnNameMapping.values)
|
||||
}
|
||||
|
||||
execute(sqlGenerator.copyTable(columnNames, sourceTableName, targetTableName))
|
||||
}
|
||||
|
||||
override suspend fun upsertTable(
|
||||
@@ -172,9 +175,7 @@ class SnowflakeAirbyteClient(
|
||||
sourceTableName: TableName,
|
||||
targetTableName: TableName
|
||||
) {
|
||||
execute(
|
||||
sqlGenerator.upsertTable(stream, columnNameMapping, sourceTableName, targetTableName)
|
||||
)
|
||||
execute(sqlGenerator.upsertTable(stream.tableSchema, sourceTableName, targetTableName))
|
||||
}
|
||||
|
||||
override suspend fun dropTable(tableName: TableName) {
|
||||
@@ -206,7 +207,7 @@ class SnowflakeAirbyteClient(
|
||||
stream: DestinationStream,
|
||||
columnNameMapping: ColumnNameMapping
|
||||
): TableSchema {
|
||||
return TableSchema(getColumnsFromStream(stream, columnNameMapping))
|
||||
return TableSchema(stream.tableSchema.columnSchema.finalSchema)
|
||||
}
|
||||
|
||||
override suspend fun applyChangeset(
|
||||
@@ -253,7 +254,7 @@ class SnowflakeAirbyteClient(
|
||||
val columnName = escapeJsonIdentifier(rs.getString("name"))
|
||||
|
||||
// Filter out airbyte columns
|
||||
if (airbyteColumnNames.contains(columnName)) {
|
||||
if (columnManager.getMetaColumnNames().contains(columnName)) {
|
||||
continue
|
||||
}
|
||||
val dataType = rs.getString("type").takeWhile { char -> char != '(' }
|
||||
@@ -271,49 +272,6 @@ class SnowflakeAirbyteClient(
|
||||
}
|
||||
}
|
||||
|
||||
internal fun getColumnsFromStream(
|
||||
stream: DestinationStream,
|
||||
columnNameMapping: ColumnNameMapping
|
||||
): Map<String, ColumnType> =
|
||||
snowflakeColumnUtils
|
||||
.columnsAndTypes(stream.schema.asColumns(), columnNameMapping)
|
||||
.filter { column -> column.columnName !in airbyteColumnNames }
|
||||
.associate { column ->
|
||||
// columnsAndTypes returns types as either `FOO` or `FOO NOT NULL`.
|
||||
// so check for that suffix.
|
||||
val nullable = !column.columnType.endsWith(NOT_NULL)
|
||||
val type =
|
||||
column.columnType
|
||||
.takeWhile { char ->
|
||||
// This is to remove any precision parts of the dialect type
|
||||
char != '('
|
||||
}
|
||||
.removeSuffix(NOT_NULL)
|
||||
.trim()
|
||||
|
||||
column.columnName to ColumnType(type, nullable)
|
||||
}
|
||||
|
||||
internal fun generateSchemaChanges(
|
||||
columnsInDb: Set<ColumnDefinition>,
|
||||
columnsInStream: Set<ColumnDefinition>
|
||||
): Triple<Set<ColumnDefinition>, Set<ColumnDefinition>, Set<ColumnDefinition>> {
|
||||
val addedColumns =
|
||||
columnsInStream.filter { it.name !in columnsInDb.map { col -> col.name } }.toSet()
|
||||
val deletedColumns =
|
||||
columnsInDb.filter { it.name !in columnsInStream.map { col -> col.name } }.toSet()
|
||||
val commonColumns =
|
||||
columnsInStream.filter { it.name in columnsInDb.map { col -> col.name } }.toSet()
|
||||
val modifiedColumns =
|
||||
commonColumns
|
||||
.filter {
|
||||
val dbType = columnsInDb.find { column -> it.name == column.name }?.type
|
||||
it.type != dbType
|
||||
}
|
||||
.toSet()
|
||||
return Triple(addedColumns, deletedColumns, modifiedColumns)
|
||||
}
|
||||
|
||||
override suspend fun getGenerationId(tableName: TableName): Long =
|
||||
try {
|
||||
dataSource.connection.use { connection ->
|
||||
@@ -326,7 +284,7 @@ class SnowflakeAirbyteClient(
|
||||
* format. In order to make sure these strings will match any column names
|
||||
* that we have formatted in-memory, re-apply the escaping.
|
||||
*/
|
||||
resultSet.getLong(snowflakeColumnUtils.getGenerationIdColumnName())
|
||||
resultSet.getLong(columnManager.getGenerationIdColumnName())
|
||||
} else {
|
||||
log.warn {
|
||||
"No generation ID found for table ${tableName.toPrettyString()}, returning 0"
|
||||
@@ -351,8 +309,8 @@ class SnowflakeAirbyteClient(
|
||||
execute(sqlGenerator.putInStage(tableName, tempFilePath))
|
||||
}
|
||||
|
||||
fun copyFromStage(tableName: TableName, filename: String) {
|
||||
execute(sqlGenerator.copyFromStage(tableName, filename))
|
||||
fun copyFromStage(tableName: TableName, filename: String, columnNames: List<String>) {
|
||||
execute(sqlGenerator.copyFromStage(tableName, filename, columnNames))
|
||||
}
|
||||
|
||||
fun describeTable(tableName: TableName): LinkedHashMap<String, String> =
|
||||
|
||||
@@ -4,47 +4,41 @@
|
||||
|
||||
package io.airbyte.integrations.destination.snowflake.dataflow
|
||||
|
||||
import io.airbyte.cdk.load.command.DestinationCatalog
|
||||
import io.airbyte.cdk.load.dataflow.aggregate.Aggregate
|
||||
import io.airbyte.cdk.load.dataflow.aggregate.AggregateFactory
|
||||
import io.airbyte.cdk.load.dataflow.aggregate.StoreKey
|
||||
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig
|
||||
import io.airbyte.cdk.load.table.TableName
|
||||
import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
|
||||
import io.airbyte.cdk.load.write.StreamStateStore
|
||||
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
|
||||
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
|
||||
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
|
||||
import io.micronaut.cache.annotation.CacheConfig
|
||||
import io.micronaut.cache.annotation.Cacheable
|
||||
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
|
||||
import jakarta.inject.Singleton
|
||||
|
||||
@Singleton
|
||||
@CacheConfig("table-columns")
|
||||
// class has to be open to make the cache stuff work
|
||||
open class SnowflakeAggregateFactory(
|
||||
class SnowflakeAggregateFactory(
|
||||
private val snowflakeClient: SnowflakeAirbyteClient,
|
||||
private val streamStateStore: StreamStateStore<DirectLoadTableExecutionConfig>,
|
||||
private val snowflakeConfiguration: SnowflakeConfiguration,
|
||||
private val snowflakeColumnUtils: SnowflakeColumnUtils,
|
||||
private val catalog: DestinationCatalog,
|
||||
private val columnManager: SnowflakeColumnManager,
|
||||
private val snowflakeRecordFormatter: SnowflakeRecordFormatter,
|
||||
) : AggregateFactory {
|
||||
override fun create(key: StoreKey): Aggregate {
|
||||
val stream = catalog.getStream(key)
|
||||
val tableName = streamStateStore.get(key)!!.tableName
|
||||
|
||||
val buffer =
|
||||
SnowflakeInsertBuffer(
|
||||
tableName = tableName,
|
||||
columns = getTableColumns(tableName),
|
||||
snowflakeClient = snowflakeClient,
|
||||
snowflakeConfiguration = snowflakeConfiguration,
|
||||
snowflakeColumnUtils = snowflakeColumnUtils,
|
||||
columnSchema = stream.tableSchema.columnSchema,
|
||||
columnManager = columnManager,
|
||||
snowflakeRecordFormatter = snowflakeRecordFormatter,
|
||||
)
|
||||
return SnowflakeAggregate(buffer = buffer)
|
||||
}
|
||||
|
||||
// We assume that a table isn't getting altered _during_ a sync.
|
||||
// This allows us to only SHOW COLUMNS once per table per sync,
|
||||
// rather than refetching it on every aggregate.
|
||||
@Cacheable
|
||||
// function has to be open to make caching work
|
||||
internal open fun getTableColumns(tableName: TableName) =
|
||||
snowflakeClient.describeTable(tableName)
|
||||
}
|
||||
|
||||
@@ -1,13 +0,0 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.integrations.destination.snowflake.db

/**
 * Jdbc destination column definition representation
 *
 * @param name
 * @param type
 */
data class ColumnDefinition(val name: String, val type: String)
@@ -4,17 +4,17 @@

package io.airbyte.integrations.destination.snowflake.db

import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.component.TableOperationsClient
import io.airbyte.cdk.load.orchestration.db.BaseDirectLoadInitialStatusGatherer
import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
import io.airbyte.cdk.load.table.BaseDirectLoadInitialStatusGatherer
import jakarta.inject.Singleton

@Singleton
class SnowflakeDirectLoadDatabaseInitialStatusGatherer(
    tableOperationsClient: TableOperationsClient,
    tempTableNameGenerator: TempTableNameGenerator,
    catalog: DestinationCatalog,
) :
    BaseDirectLoadInitialStatusGatherer(
        tableOperationsClient,
        tempTableNameGenerator,
        catalog,
    )
@@ -1,95 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination.snowflake.db
|
||||
|
||||
import io.airbyte.cdk.ConfigErrorException
|
||||
import io.airbyte.cdk.load.command.DestinationStream
|
||||
import io.airbyte.cdk.load.orchestration.db.ColumnNameGenerator
|
||||
import io.airbyte.cdk.load.orchestration.db.FinalTableNameGenerator
|
||||
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TypingDedupingUtil
|
||||
import io.airbyte.cdk.load.table.TableName
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
import io.airbyte.integrations.destination.snowflake.sql.QUOTE
|
||||
import jakarta.inject.Singleton
|
||||
|
||||
@Singleton
|
||||
class SnowflakeFinalTableNameGenerator(private val config: SnowflakeConfiguration) :
|
||||
FinalTableNameGenerator {
|
||||
override fun getTableName(streamDescriptor: DestinationStream.Descriptor): TableName {
|
||||
val namespace = streamDescriptor.namespace ?: config.schema
|
||||
return if (!config.legacyRawTablesOnly) {
|
||||
TableName(
|
||||
namespace = namespace.toSnowflakeCompatibleName(),
|
||||
name = streamDescriptor.name.toSnowflakeCompatibleName(),
|
||||
)
|
||||
} else {
|
||||
TableName(
|
||||
namespace = config.internalTableSchema,
|
||||
name =
|
||||
TypingDedupingUtil.concatenateRawTableName(
|
||||
namespace = escapeJsonIdentifier(namespace),
|
||||
name = escapeJsonIdentifier(streamDescriptor.name),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Singleton
|
||||
class SnowflakeColumnNameGenerator(private val config: SnowflakeConfiguration) :
|
||||
ColumnNameGenerator {
|
||||
override fun getColumnName(column: String): ColumnNameGenerator.ColumnName {
|
||||
return if (!config.legacyRawTablesOnly) {
|
||||
ColumnNameGenerator.ColumnName(
|
||||
column.toSnowflakeCompatibleName(),
|
||||
column.toSnowflakeCompatibleName(),
|
||||
)
|
||||
} else {
|
||||
ColumnNameGenerator.ColumnName(
|
||||
column,
|
||||
column,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Escapes double-quotes in a JSON identifier by doubling them. This shit is legacy -- I don't know
|
||||
* why this would be necessary but no harm in keeping it so I am keeping it.
|
||||
*
|
||||
* @return The escaped identifier.
|
||||
*/
|
||||
fun escapeJsonIdentifier(identifier: String): String {
|
||||
// Note that we don't need to escape backslashes here!
|
||||
// The only special character in an identifier is the double-quote, which needs to be
|
||||
// doubled.
|
||||
return identifier.replace(QUOTE, "$QUOTE$QUOTE")
|
||||
}
|
||||
|
||||
/**
|
||||
* Transforms a string to be compatible with Snowflake table and column names.
|
||||
*
|
||||
* @return The transformed string suitable for Snowflake identifiers.
|
||||
*/
|
||||
fun String.toSnowflakeCompatibleName(): String {
|
||||
var identifier = this
|
||||
|
||||
// Handle empty strings
|
||||
if (identifier.isEmpty()) {
|
||||
throw ConfigErrorException("Empty string is invalid identifier")
|
||||
}
|
||||
|
||||
// Snowflake scripting language does something weird when the `${` bigram shows up in the
|
||||
// script so replace these with something else.
|
||||
// For completeness, if we trigger this, also replace closing curly braces with underscores.
|
||||
if (identifier.contains("\${")) {
|
||||
identifier = identifier.replace("$", "_").replace("{", "_").replace("}", "_")
|
||||
}
|
||||
|
||||
// Escape double quotes
|
||||
identifier = escapeJsonIdentifier(identifier)
|
||||
|
||||
return identifier.uppercase()
|
||||
}
|
||||
@@ -0,0 +1,116 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.integrations.destination.snowflake.schema

import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_EXTRACTED_AT
import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_GENERATION_ID
import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_META
import io.airbyte.integrations.destination.snowflake.sql.SNOWFLAKE_AB_RAW_ID
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDataType
import jakarta.inject.Singleton

/**
 * Manages column names and ordering for Snowflake tables based on whether legacy raw tables mode is
 * enabled.
 *
 * TODO: We should add meta column munging and raw table support to the CDK, so this extra layer of
 * management shouldn't be necessary.
 */
@Singleton
class SnowflakeColumnManager(
    private val config: SnowflakeConfiguration,
) {
    /**
     * Get the list of column names for a table in the order they should appear in the CSV file and
     * COPY INTO statement.
     *
     * Warning: MUST match the order defined in SnowflakeRecordFormatter
     *
     * @param columnSchema The schema containing column information (ignored in raw mode)
     * @return List of column names in the correct order
     */
    fun getTableColumnNames(columnSchema: ColumnSchema): List<String> {
        return buildList {
            addAll(getMetaColumnNames())
            addAll(columnSchema.finalSchema.keys)
        }
    }

    /**
     * Get the list of Airbyte meta column names. In schema mode, these are uppercase. In raw mode,
     * they are lowercase and include loaded_at.
     *
     * @return Set of meta column names
     */
    fun getMetaColumnNames(): Set<String> =
        if (config.legacyRawTablesOnly) {
            Constants.rawModeMetaColNames
        } else {
            Constants.schemaModeMetaColNames
        }

    /**
     * Get the Airbyte meta columns as a map of column name to ColumnType. This provides both the
     * column names and their types for table creation.
     *
     * @param columnSchema The user column schema (used to check for CDC columns)
     * @return Map of meta column names to their types
     */
    fun getMetaColumns(): LinkedHashMap<String, ColumnType> {
        return if (config.legacyRawTablesOnly) {
            Constants.rawModeMetaColumns
        } else {
            Constants.schemaModeMetaColumns
        }
    }

    fun getGenerationIdColumnName(): String {
        return if (config.legacyRawTablesOnly) {
            Meta.COLUMN_NAME_AB_GENERATION_ID
        } else {
            SNOWFLAKE_AB_GENERATION_ID
        }
    }

    object Constants {
        val rawModeMetaColumns =
            linkedMapOf(
                Meta.COLUMN_NAME_AB_RAW_ID to ColumnType(SnowflakeDataType.VARCHAR.typeName, false),
                Meta.COLUMN_NAME_AB_EXTRACTED_AT to
                    ColumnType(
                        SnowflakeDataType.TIMESTAMP_TZ.typeName,
                        false,
                    ),
                Meta.COLUMN_NAME_AB_META to ColumnType(SnowflakeDataType.VARIANT.typeName, false),
                Meta.COLUMN_NAME_AB_GENERATION_ID to
                    ColumnType(
                        SnowflakeDataType.NUMBER.typeName,
                        true,
                    ),
                Meta.COLUMN_NAME_AB_LOADED_AT to
                    ColumnType(
                        SnowflakeDataType.TIMESTAMP_TZ.typeName,
                        true,
                    ),
            )

        val schemaModeMetaColumns =
            linkedMapOf(
                SNOWFLAKE_AB_RAW_ID to ColumnType(SnowflakeDataType.VARCHAR.typeName, false),
                SNOWFLAKE_AB_EXTRACTED_AT to
                    ColumnType(SnowflakeDataType.TIMESTAMP_TZ.typeName, false),
                SNOWFLAKE_AB_META to ColumnType(SnowflakeDataType.VARIANT.typeName, false),
                SNOWFLAKE_AB_GENERATION_ID to ColumnType(SnowflakeDataType.NUMBER.typeName, true),
            )

        val rawModeMetaColNames: Set<String> = rawModeMetaColumns.keys

        val schemaModeMetaColNames: Set<String> = schemaModeMetaColumns.keys
    }
}
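Because the CSV and COPY INTO column order must match SnowflakeRecordFormatter, getTableColumnNames above always emits the Airbyte meta columns first and the user columns afterwards, in finalSchema iteration order. A minimal standalone sketch of that contract (illustrative only; the column names below are placeholders, not the real SNOWFLAKE_AB_* or Meta.COLUMN_NAME_* constants):

// Simplified sketch of the ordering contract described above; not code from this commit.
fun orderedColumns(metaColumns: Set<String>, userColumns: Set<String>): List<String> =
    buildList {
        addAll(metaColumns) // Airbyte meta columns always come first
        addAll(userColumns) // then user columns, in final-schema order
    }

fun main() {
    val meta = linkedSetOf("_AB_RAW_ID", "_AB_EXTRACTED_AT", "_AB_META", "_AB_GENERATION_ID")
    val user = linkedSetOf("ID", "NAME")
    // Prints: [_AB_RAW_ID, _AB_EXTRACTED_AT, _AB_META, _AB_GENERATION_ID, ID, NAME]
    println(orderedColumns(meta, user))
}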
@@ -0,0 +1,34 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.integrations.destination.snowflake.schema

import io.airbyte.cdk.ConfigErrorException
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier

/**
 * Transforms a string to be compatible with Snowflake table and column names.
 *
 * @return The transformed string suitable for Snowflake identifiers.
 */
fun String.toSnowflakeCompatibleName(): String {
    var identifier = this

    // Handle empty strings
    if (identifier.isEmpty()) {
        throw ConfigErrorException("Empty string is invalid identifier")
    }

    // Snowflake scripting language does something weird when the `${` bigram shows up in the
    // script so replace these with something else.
    // For completeness, if we trigger this, also replace closing curly braces with underscores.
    if (identifier.contains("\${")) {
        identifier = identifier.replace("$", "_").replace("{", "_").replace("}", "_")
    }

    // Escape double quotes
    identifier = escapeJsonIdentifier(identifier)

    return identifier.uppercase()
}
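For reference, the transformation above uppercases identifiers, doubles any embedded double quotes, and replaces '$', '{' and '}' with underscores when the `${` bigram appears. A hedged, standalone re-statement of those rules (not part of the commit; it mirrors the function as written rather than calling the real CDK helpers):

// Illustrative only: reimplements the rules of toSnowflakeCompatibleName() shown above.
fun illustrateSnowflakeName(raw: String): String {
    require(raw.isNotEmpty()) { "Empty string is an invalid identifier" }
    var id = raw
    // The `${` bigram triggers replacement of '$', '{', and '}' with underscores.
    if (id.contains("\${")) {
        id = id.replace("$", "_").replace("{", "_").replace("}", "_")
    }
    // Double quotes are escaped by doubling, then the identifier is uppercased.
    return id.replace("\"", "\"\"").uppercase()
}

fun main() {
    println(illustrateSnowflakeName("users"))       // USERS
    println(illustrateSnowflakeName("cost\${usd}")) // COST__USD_
    println(illustrateSnowflakeName("a\"b"))        // A""B
}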
@@ -0,0 +1,121 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination.snowflake.schema
|
||||
|
||||
import io.airbyte.cdk.load.command.DestinationStream
|
||||
import io.airbyte.cdk.load.component.ColumnType
|
||||
import io.airbyte.cdk.load.data.ArrayType
|
||||
import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
|
||||
import io.airbyte.cdk.load.data.BooleanType
|
||||
import io.airbyte.cdk.load.data.DateType
|
||||
import io.airbyte.cdk.load.data.FieldType
|
||||
import io.airbyte.cdk.load.data.IntegerType
|
||||
import io.airbyte.cdk.load.data.NumberType
|
||||
import io.airbyte.cdk.load.data.ObjectType
|
||||
import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
|
||||
import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
|
||||
import io.airbyte.cdk.load.data.StringType
|
||||
import io.airbyte.cdk.load.data.TimeTypeWithTimezone
|
||||
import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
|
||||
import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
|
||||
import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
|
||||
import io.airbyte.cdk.load.data.UnionType
|
||||
import io.airbyte.cdk.load.data.UnknownType
|
||||
import io.airbyte.cdk.load.message.Meta
|
||||
import io.airbyte.cdk.load.schema.TableSchemaMapper
|
||||
import io.airbyte.cdk.load.schema.model.StreamTableSchema
|
||||
import io.airbyte.cdk.load.schema.model.TableName
|
||||
import io.airbyte.cdk.load.table.TempTableNameGenerator
|
||||
import io.airbyte.cdk.load.table.TypingDedupingUtil
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDataType
|
||||
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
|
||||
import jakarta.inject.Singleton
|
||||
|
||||
@Singleton
|
||||
class SnowflakeTableSchemaMapper(
|
||||
private val config: SnowflakeConfiguration,
|
||||
private val tempTableNameGenerator: TempTableNameGenerator,
|
||||
) : TableSchemaMapper {
|
||||
override fun toFinalTableName(desc: DestinationStream.Descriptor): TableName {
|
||||
val namespace = desc.namespace ?: config.schema
|
||||
return if (!config.legacyRawTablesOnly) {
|
||||
TableName(
|
||||
namespace = namespace.toSnowflakeCompatibleName(),
|
||||
name = desc.name.toSnowflakeCompatibleName(),
|
||||
)
|
||||
} else {
|
||||
TableName(
|
||||
namespace = config.internalTableSchema,
|
||||
name =
|
||||
TypingDedupingUtil.concatenateRawTableName(
|
||||
namespace = escapeJsonIdentifier(namespace),
|
||||
name = escapeJsonIdentifier(desc.name),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
override fun toTempTableName(tableName: TableName): TableName {
|
||||
return tempTableNameGenerator.generate(tableName)
|
||||
}
|
||||
|
||||
override fun toColumnName(name: String): String {
|
||||
return if (!config.legacyRawTablesOnly) {
|
||||
name.toSnowflakeCompatibleName()
|
||||
} else {
|
||||
// In legacy mode, column names are not transformed
|
||||
name
|
||||
}
|
||||
}
|
||||
|
||||
override fun toColumnType(fieldType: FieldType): ColumnType {
|
||||
val snowflakeType =
|
||||
when (fieldType.type) {
|
||||
// Simple types
|
||||
BooleanType -> SnowflakeDataType.BOOLEAN.typeName
|
||||
IntegerType -> SnowflakeDataType.NUMBER.typeName
|
||||
NumberType -> SnowflakeDataType.FLOAT.typeName
|
||||
StringType -> SnowflakeDataType.VARCHAR.typeName
|
||||
|
||||
// Temporal types
|
||||
DateType -> SnowflakeDataType.DATE.typeName
|
||||
TimeTypeWithTimezone -> SnowflakeDataType.VARCHAR.typeName
|
||||
TimeTypeWithoutTimezone -> SnowflakeDataType.TIME.typeName
|
||||
TimestampTypeWithTimezone -> SnowflakeDataType.TIMESTAMP_TZ.typeName
|
||||
TimestampTypeWithoutTimezone -> SnowflakeDataType.TIMESTAMP_NTZ.typeName
|
||||
|
||||
// Semistructured types
|
||||
is ArrayType,
|
||||
ArrayTypeWithoutSchema -> SnowflakeDataType.ARRAY.typeName
|
||||
is ObjectType,
|
||||
ObjectTypeWithEmptySchema,
|
||||
ObjectTypeWithoutSchema -> SnowflakeDataType.OBJECT.typeName
|
||||
is UnionType -> SnowflakeDataType.VARIANT.typeName
|
||||
is UnknownType -> SnowflakeDataType.VARIANT.typeName
|
||||
}
|
||||
|
||||
return ColumnType(snowflakeType, fieldType.nullable)
|
||||
}
|
||||
|
||||
override fun toFinalSchema(tableSchema: StreamTableSchema): StreamTableSchema {
|
||||
if (!config.legacyRawTablesOnly) {
|
||||
return tableSchema
|
||||
}
|
||||
|
||||
return StreamTableSchema(
|
||||
tableNames = tableSchema.tableNames,
|
||||
columnSchema =
|
||||
tableSchema.columnSchema.copy(
|
||||
finalSchema =
|
||||
mapOf(
|
||||
Meta.COLUMN_NAME_DATA to
|
||||
ColumnType(SnowflakeDataType.OBJECT.typeName, false)
|
||||
)
|
||||
),
|
||||
importType = tableSchema.importType,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1,201 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination.snowflake.sql
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting
|
||||
import io.airbyte.cdk.load.data.AirbyteType
|
||||
import io.airbyte.cdk.load.data.ArrayType
|
||||
import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
|
||||
import io.airbyte.cdk.load.data.BooleanType
|
||||
import io.airbyte.cdk.load.data.DateType
|
||||
import io.airbyte.cdk.load.data.FieldType
|
||||
import io.airbyte.cdk.load.data.IntegerType
|
||||
import io.airbyte.cdk.load.data.NumberType
|
||||
import io.airbyte.cdk.load.data.ObjectType
|
||||
import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
|
||||
import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
|
||||
import io.airbyte.cdk.load.data.StringType
|
||||
import io.airbyte.cdk.load.data.TimeTypeWithTimezone
|
||||
import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
|
||||
import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
|
||||
import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
|
||||
import io.airbyte.cdk.load.data.UnionType
|
||||
import io.airbyte.cdk.load.data.UnknownType
|
||||
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
|
||||
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
|
||||
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT
|
||||
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
|
||||
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
|
||||
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA
|
||||
import io.airbyte.cdk.load.table.ColumnNameMapping
|
||||
import io.airbyte.integrations.destination.snowflake.db.SnowflakeColumnNameGenerator
|
||||
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
import jakarta.inject.Singleton
|
||||
import kotlin.collections.component1
|
||||
import kotlin.collections.component2
|
||||
import kotlin.collections.joinToString
|
||||
import kotlin.collections.map
|
||||
import kotlin.collections.plus
|
||||
|
||||
internal const val NOT_NULL = "NOT NULL"
|
||||
|
||||
internal val DEFAULT_COLUMNS =
|
||||
listOf(
|
||||
ColumnAndType(
|
||||
columnName = COLUMN_NAME_AB_RAW_ID,
|
||||
columnType = "${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL"
|
||||
),
|
||||
ColumnAndType(
|
||||
columnName = COLUMN_NAME_AB_EXTRACTED_AT,
|
||||
columnType = "${SnowflakeDataType.TIMESTAMP_TZ.typeName} $NOT_NULL"
|
||||
),
|
||||
ColumnAndType(
|
||||
columnName = COLUMN_NAME_AB_META,
|
||||
columnType = "${SnowflakeDataType.VARIANT.typeName} $NOT_NULL"
|
||||
),
|
||||
ColumnAndType(
|
||||
columnName = COLUMN_NAME_AB_GENERATION_ID,
|
||||
columnType = SnowflakeDataType.NUMBER.typeName
|
||||
),
|
||||
)
|
||||
|
||||
internal val RAW_DATA_COLUMN =
|
||||
ColumnAndType(
|
||||
columnName = COLUMN_NAME_DATA,
|
||||
columnType = "${SnowflakeDataType.VARIANT.typeName} $NOT_NULL"
|
||||
)
|
||||
|
||||
internal val RAW_COLUMNS =
|
||||
listOf(
|
||||
ColumnAndType(
|
||||
columnName = COLUMN_NAME_AB_LOADED_AT,
|
||||
columnType = SnowflakeDataType.TIMESTAMP_TZ.typeName
|
||||
),
|
||||
RAW_DATA_COLUMN
|
||||
)
|
||||
|
||||
@Singleton
|
||||
class SnowflakeColumnUtils(
|
||||
private val snowflakeConfiguration: SnowflakeConfiguration,
|
||||
private val snowflakeColumnNameGenerator: SnowflakeColumnNameGenerator,
|
||||
) {
|
||||
|
||||
@VisibleForTesting
|
||||
internal fun defaultColumns(): List<ColumnAndType> =
|
||||
if (snowflakeConfiguration.legacyRawTablesOnly) {
|
||||
DEFAULT_COLUMNS + RAW_COLUMNS
|
||||
} else {
|
||||
DEFAULT_COLUMNS
|
||||
}
|
||||
|
||||
internal fun formattedDefaultColumns(): List<ColumnAndType> =
|
||||
defaultColumns().map {
|
||||
ColumnAndType(
|
||||
columnName = formatColumnName(it.columnName, false),
|
||||
columnType = it.columnType,
|
||||
)
|
||||
}
|
||||
|
||||
fun getGenerationIdColumnName(): String {
|
||||
return if (snowflakeConfiguration.legacyRawTablesOnly) {
|
||||
COLUMN_NAME_AB_GENERATION_ID
|
||||
} else {
|
||||
COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()
|
||||
}
|
||||
}
|
||||
|
||||
fun getColumnNames(columnNameMapping: ColumnNameMapping): String =
|
||||
if (snowflakeConfiguration.legacyRawTablesOnly) {
|
||||
getFormattedDefaultColumnNames(true).joinToString(",")
|
||||
} else {
|
||||
(getFormattedDefaultColumnNames(true) +
|
||||
columnNameMapping.map { (_, actualName) -> actualName.quote() })
|
||||
.joinToString(",")
|
||||
}
|
||||
|
||||
fun getFormattedDefaultColumnNames(quote: Boolean = false): List<String> =
|
||||
defaultColumns().map { formatColumnName(it.columnName, quote) }
|
||||
|
||||
fun getFormattedColumnNames(
|
||||
columns: Map<String, FieldType>,
|
||||
columnNameMapping: ColumnNameMapping,
|
||||
quote: Boolean = true,
|
||||
): List<String> =
|
||||
if (snowflakeConfiguration.legacyRawTablesOnly) {
|
||||
getFormattedDefaultColumnNames(quote)
|
||||
} else {
|
||||
getFormattedDefaultColumnNames(quote) +
|
||||
columns.map { (fieldName, _) ->
|
||||
val columnName = columnNameMapping[fieldName] ?: fieldName
|
||||
if (quote) columnName.quote() else columnName
|
||||
}
|
||||
}
|
||||
|
||||
fun columnsAndTypes(
|
||||
columns: Map<String, FieldType>,
|
||||
columnNameMapping: ColumnNameMapping
|
||||
): List<ColumnAndType> =
|
||||
if (snowflakeConfiguration.legacyRawTablesOnly) {
|
||||
formattedDefaultColumns()
|
||||
} else {
|
||||
formattedDefaultColumns() +
|
||||
columns.map { (fieldName, type) ->
|
||||
val columnName = columnNameMapping[fieldName] ?: fieldName
|
||||
val typeName = toDialectType(type.type)
|
||||
ColumnAndType(
|
||||
columnName = columnName,
|
||||
columnType = if (type.nullable) typeName else "$typeName $NOT_NULL",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fun formatColumnName(
|
||||
columnName: String,
|
||||
quote: Boolean = true,
|
||||
): String {
|
||||
val formattedColumnName =
|
||||
if (columnName == COLUMN_NAME_DATA) columnName
|
||||
else snowflakeColumnNameGenerator.getColumnName(columnName).displayName
|
||||
return if (quote) formattedColumnName.quote() else formattedColumnName
|
||||
}
|
||||
|
||||
fun toDialectType(type: AirbyteType): String =
|
||||
when (type) {
|
||||
// Simple types
|
||||
BooleanType -> SnowflakeDataType.BOOLEAN.typeName
|
||||
IntegerType -> SnowflakeDataType.NUMBER.typeName
|
||||
NumberType -> SnowflakeDataType.FLOAT.typeName
|
||||
StringType -> SnowflakeDataType.VARCHAR.typeName
|
||||
|
||||
// Temporal types
|
||||
DateType -> SnowflakeDataType.DATE.typeName
|
||||
TimeTypeWithTimezone -> SnowflakeDataType.VARCHAR.typeName
|
||||
TimeTypeWithoutTimezone -> SnowflakeDataType.TIME.typeName
|
||||
TimestampTypeWithTimezone -> SnowflakeDataType.TIMESTAMP_TZ.typeName
|
||||
TimestampTypeWithoutTimezone -> SnowflakeDataType.TIMESTAMP_NTZ.typeName
|
||||
|
||||
// Semistructured types
|
||||
is ArrayType,
|
||||
ArrayTypeWithoutSchema -> SnowflakeDataType.ARRAY.typeName
|
||||
is ObjectType,
|
||||
ObjectTypeWithEmptySchema,
|
||||
ObjectTypeWithoutSchema -> SnowflakeDataType.OBJECT.typeName
|
||||
is UnionType -> SnowflakeDataType.VARIANT.typeName
|
||||
is UnknownType -> SnowflakeDataType.VARIANT.typeName
|
||||
}
|
||||
}
|
||||
|
||||
data class ColumnAndType(val columnName: String, val columnType: String) {
override fun toString(): String {
return "${columnName.quote()} $columnType"
}
}

/**
* Surrounds the string instance with double quotation marks (e.g. "some string" -> "\"some
* string\"").
*/
fun String.quote() = "$QUOTE$this$QUOTE"
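A quick, hedged usage example of the two helpers above, showing how a ColumnAndType renders as a quoted column declaration (the constants and column values here are stand-ins, not the connector's API):

private const val Q = "\""                       // stands in for QUOTE above
private fun String.q() = "$Q$this$Q"             // stands in for String.quote()

private data class Col(val columnName: String, val columnType: String) {
    override fun toString() = "${columnName.q()} $columnType"
}

fun main() {
    println(Col("_airbyte_meta", "VARIANT NOT NULL")) // "_airbyte_meta" VARIANT NOT NULL
    println("USERS".q())                              // "USERS"
}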
@@ -10,7 +10,7 @@ package io.airbyte.integrations.destination.snowflake.sql
*/
enum class SnowflakeDataType(val typeName: String) {
// Numeric types
NUMBER("NUMBER(38,0)"),
NUMBER("NUMBER"),
FLOAT("FLOAT"),

// String & binary types

@@ -4,16 +4,19 @@

package io.airbyte.integrations.destination.snowflake.sql

import io.airbyte.cdk.load.command.Dedupe
import io.airbyte.cdk.load.command.DestinationStream
import com.google.common.annotations.VisibleForTesting
import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.component.ColumnTypeChange
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.cdk.load.schema.model.StreamTableSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN
import io.airbyte.cdk.load.table.ColumnNameMapping
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.util.UUIDGenerator
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.CdcDeletionMode
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.write.load.CSV_FIELD_SEPARATOR
@@ -22,6 +25,14 @@ import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton

internal const val COUNT_TOTAL_ALIAS = "TOTAL"
internal const val NOT_NULL = "NOT NULL"

// Snowflake-compatible (uppercase) versions of the Airbyte meta column names
internal val SNOWFLAKE_AB_RAW_ID = COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_EXTRACTED_AT = COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_META = COLUMN_NAME_AB_META.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_GENERATION_ID = COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()
internal val SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN = CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()

private val log = KotlinLogging.logger {}

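A minimal sketch of what the precomputed SNOWFLAKE_AB_* values are for, assuming toSnowflakeCompatibleName() essentially uppercases an identifier (the real helper lives elsewhere in this connector and may do more):

// Assumption: Snowflake-side meta column names are the uppercased Airbyte names.
fun toSnowflakeCompatibleNameSketch(name: String): String = name.uppercase()

val rawIdSketch = toSnowflakeCompatibleNameSketch("_airbyte_raw_id")              // _AIRBYTE_RAW_ID
val extractedAtSketch = toSnowflakeCompatibleNameSketch("_airbyte_extracted_at")  // _AIRBYTE_EXTRACTED_AT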
@@ -36,80 +47,91 @@ fun String.andLog(): String {
|
||||
|
||||
@Singleton
|
||||
class SnowflakeDirectLoadSqlGenerator(
|
||||
private val columnUtils: SnowflakeColumnUtils,
|
||||
private val uuidGenerator: UUIDGenerator,
|
||||
private val snowflakeConfiguration: SnowflakeConfiguration,
|
||||
private val snowflakeSqlNameUtils: SnowflakeSqlNameUtils,
|
||||
private val config: SnowflakeConfiguration,
|
||||
private val columnManager: SnowflakeColumnManager,
|
||||
) {
|
||||
fun countTable(tableName: TableName): String {
|
||||
return "SELECT COUNT(*) AS $COUNT_TOTAL_ALIAS FROM ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog()
|
||||
return "SELECT COUNT(*) AS $COUNT_TOTAL_ALIAS FROM ${fullyQualifiedName(tableName)}".andLog()
|
||||
}
|
||||
|
||||
fun createNamespace(namespace: String): String {
|
||||
return "CREATE SCHEMA IF NOT EXISTS ${snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)}".andLog()
|
||||
return "CREATE SCHEMA IF NOT EXISTS ${fullyQualifiedNamespace(namespace)}".andLog()
|
||||
}
|
||||
|
||||
fun createTable(
|
||||
stream: DestinationStream,
|
||||
tableName: TableName,
|
||||
columnNameMapping: ColumnNameMapping,
|
||||
tableSchema: StreamTableSchema,
|
||||
replace: Boolean
|
||||
): String {
|
||||
val finalSchema = tableSchema.columnSchema.finalSchema
|
||||
val metaColumns = columnManager.getMetaColumns()
|
||||
|
||||
// Build column declarations from the meta columns and user schema
|
||||
val columnDeclarations =
|
||||
columnUtils
|
||||
.columnsAndTypes(stream.schema.asColumns(), columnNameMapping)
|
||||
.joinToString(",\n")
|
||||
buildList {
|
||||
// Add Airbyte meta columns from the column manager
|
||||
metaColumns.forEach { (columnName, columnType) ->
|
||||
val nullability = if (columnType.nullable) "" else " NOT NULL"
|
||||
add("${columnName.quote()} ${columnType.type}$nullability")
|
||||
}
|
||||
|
||||
// Add user columns from the munged schema
|
||||
finalSchema.forEach { (columnName, columnType) ->
|
||||
val nullability = if (columnType.nullable) "" else " NOT NULL"
|
||||
add("${columnName.quote()} ${columnType.type}$nullability")
|
||||
}
|
||||
}
|
||||
.joinToString(",\n ")
|
||||
|
||||
// Snowflake supports CREATE OR REPLACE TABLE, which is simpler than drop+recreate
|
||||
val createOrReplace = if (replace) "CREATE OR REPLACE" else "CREATE"
|
||||
|
||||
val createTableStatement =
|
||||
"""
|
||||
$createOrReplace TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)} (
|
||||
$columnDeclarations
|
||||
)
|
||||
""".trimIndent()
|
||||
|$createOrReplace TABLE ${fullyQualifiedName(tableName)} (
|
||||
| $columnDeclarations
|
||||
|)
|
||||
""".trimMargin() // Something was tripping up trimIndent so we opt for trimMargin
|
||||
|
||||
return createTableStatement.andLog()
|
||||
}
|
||||
|
||||
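As a hedged illustration, the createTable() path above yields a statement shaped roughly like the following for a hypothetical stream with a single user column (database, schema, table, and types are examples only):

val exampleCreateTable = """
    |CREATE OR REPLACE TABLE "AIRBYTE_DB"."PUBLIC"."USERS" (
    |  "_AIRBYTE_RAW_ID" VARCHAR NOT NULL,
    |  "_AIRBYTE_EXTRACTED_AT" TIMESTAMP_TZ NOT NULL,
    |  "_AIRBYTE_META" VARIANT NOT NULL,
    |  "_AIRBYTE_GENERATION_ID" NUMBER,
    |  "NAME" VARCHAR
    |)
    """.trimMargin()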
fun showColumns(tableName: TableName): String =
|
||||
"SHOW COLUMNS IN TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog()
|
||||
"SHOW COLUMNS IN TABLE ${fullyQualifiedName(tableName)}".andLog()
|
||||
|
||||
fun copyTable(
|
||||
columnNameMapping: ColumnNameMapping,
|
||||
columnNames: Set<String>,
|
||||
sourceTableName: TableName,
|
||||
targetTableName: TableName
|
||||
): String {
|
||||
val columnNames = columnUtils.getColumnNames(columnNameMapping)
|
||||
val columnList = columnNames.joinToString(", ") { it.quote() }
|
||||
|
||||
return """
|
||||
INSERT INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)}
|
||||
INSERT INTO ${fullyQualifiedName(targetTableName)}
|
||||
(
|
||||
$columnNames
|
||||
$columnList
|
||||
)
|
||||
SELECT
|
||||
$columnNames
|
||||
FROM ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)}
|
||||
$columnList
|
||||
FROM ${fullyQualifiedName(sourceTableName)}
|
||||
"""
|
||||
.trimIndent()
|
||||
.andLog()
|
||||
}
|
||||
|
||||
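For reference, a hedged sketch of the INSERT ... SELECT that copyTable() produces once the explicit column list is built (all identifiers are illustrative):

val exampleCopyTable = """
    |INSERT INTO "AIRBYTE_DB"."PUBLIC"."USERS"
    |(
    |"_AIRBYTE_RAW_ID", "_AIRBYTE_EXTRACTED_AT", "_AIRBYTE_META", "_AIRBYTE_GENERATION_ID", "NAME"
    |)
    |SELECT
    |"_AIRBYTE_RAW_ID", "_AIRBYTE_EXTRACTED_AT", "_AIRBYTE_META", "_AIRBYTE_GENERATION_ID", "NAME"
    |FROM "AIRBYTE_DB"."PUBLIC"."USERS_AIRBYTE_TMP"
    """.trimMargin()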
fun upsertTable(
|
||||
stream: DestinationStream,
|
||||
columnNameMapping: ColumnNameMapping,
|
||||
tableSchema: StreamTableSchema,
|
||||
sourceTableName: TableName,
|
||||
targetTableName: TableName
|
||||
): String {
|
||||
val importType = stream.importType as Dedupe
|
||||
val finalSchema = tableSchema.columnSchema.finalSchema
|
||||
|
||||
// Build primary key matching condition
|
||||
val pks = tableSchema.getPrimaryKey().flatten()
|
||||
val pkEquivalent =
|
||||
if (importType.primaryKey.isNotEmpty()) {
|
||||
importType.primaryKey.joinToString(" AND ") { fieldPath ->
|
||||
val fieldName = fieldPath.first()
|
||||
val columnName = columnNameMapping[fieldName] ?: fieldName
|
||||
if (pks.isNotEmpty()) {
|
||||
pks.joinToString(" AND ") { columnName ->
|
||||
val targetTableColumnName = "target_table.${columnName.quote()}"
|
||||
val newRecordColumnName = "new_record.${columnName.quote()}"
|
||||
"""($targetTableColumnName = $newRecordColumnName OR ($targetTableColumnName IS NULL AND $newRecordColumnName IS NULL))"""
|
||||
@@ -120,80 +142,62 @@ class SnowflakeDirectLoadSqlGenerator(
|
||||
}
|
||||
|
||||
// Build column lists for INSERT and UPDATE
|
||||
val columnList: String =
|
||||
columnUtils
|
||||
.getFormattedColumnNames(
|
||||
columns = stream.schema.asColumns(),
|
||||
columnNameMapping = columnNameMapping,
|
||||
quote = false,
|
||||
)
|
||||
.joinToString(
|
||||
",\n",
|
||||
) {
|
||||
it.quote()
|
||||
}
|
||||
val allColumns = buildList {
|
||||
add(SNOWFLAKE_AB_RAW_ID)
|
||||
add(SNOWFLAKE_AB_EXTRACTED_AT)
|
||||
add(SNOWFLAKE_AB_META)
|
||||
add(SNOWFLAKE_AB_GENERATION_ID)
|
||||
addAll(finalSchema.keys)
|
||||
}
|
||||
|
||||
val columnList: String = allColumns.joinToString(",\n ") { it.quote() }
|
||||
val newRecordColumnList: String =
|
||||
columnUtils
|
||||
.getFormattedColumnNames(
|
||||
columns = stream.schema.asColumns(),
|
||||
columnNameMapping = columnNameMapping,
|
||||
quote = false,
|
||||
)
|
||||
.joinToString(",\n") { "new_record.${it.quote()}" }
|
||||
allColumns.joinToString(",\n ") { "new_record.${it.quote()}" }
|
||||
|
||||
// Get deduped records from source
|
||||
val selectSourceRecords = selectDedupedRecords(stream, sourceTableName, columnNameMapping)
|
||||
val selectSourceRecords = selectDedupedRecords(tableSchema, sourceTableName)
|
||||
|
||||
// Build cursor comparison for determining which record is newer
|
||||
val cursorComparison: String
|
||||
if (importType.cursor.isNotEmpty()) {
|
||||
val cursorFieldName = importType.cursor.first()
|
||||
val cursor = (columnNameMapping[cursorFieldName] ?: cursorFieldName)
|
||||
val cursor = tableSchema.getCursor().firstOrNull()
|
||||
if (cursor != null) {
|
||||
val targetTableCursor = "target_table.${cursor.quote()}"
|
||||
val newRecordCursor = "new_record.${cursor.quote()}"
|
||||
cursorComparison =
|
||||
"""
|
||||
(
|
||||
$targetTableCursor < $newRecordCursor
|
||||
OR ($targetTableCursor = $newRecordCursor AND target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}")
|
||||
OR ($targetTableCursor IS NULL AND $newRecordCursor IS NULL AND target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}")
|
||||
OR ($targetTableCursor = $newRecordCursor AND target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT")
|
||||
OR ($targetTableCursor IS NULL AND $newRecordCursor IS NULL AND target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT")
|
||||
OR ($targetTableCursor IS NULL AND $newRecordCursor IS $NOT_NULL)
|
||||
)
|
||||
""".trimIndent()
|
||||
} else {
|
||||
// No cursor - use extraction timestamp only
|
||||
cursorComparison =
|
||||
"""target_table."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" < new_record."${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}""""
|
||||
"""target_table."$SNOWFLAKE_AB_EXTRACTED_AT" < new_record."$SNOWFLAKE_AB_EXTRACTED_AT""""
|
||||
}
|
||||
|
||||
// Build column assignments for UPDATE
|
||||
val columnAssignments: String =
|
||||
columnUtils
|
||||
.getFormattedColumnNames(
|
||||
columns = stream.schema.asColumns(),
|
||||
columnNameMapping = columnNameMapping,
|
||||
quote = false,
|
||||
)
|
||||
.joinToString(",\n") { column ->
|
||||
"${column.quote()} = new_record.${column.quote()}"
|
||||
}
|
||||
allColumns.joinToString(",\n ") { column ->
|
||||
"${column.quote()} = new_record.${column.quote()}"
|
||||
}
|
||||
|
||||
// Handle CDC deletions based on mode
|
||||
val cdcDeleteClause: String
|
||||
val cdcSkipInsertClause: String
|
||||
if (
|
||||
stream.schema.asColumns().containsKey(CDC_DELETED_AT_COLUMN) &&
|
||||
snowflakeConfiguration.cdcDeletionMode == CdcDeletionMode.HARD_DELETE
|
||||
finalSchema.containsKey(SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN) &&
|
||||
config.cdcDeletionMode == CdcDeletionMode.HARD_DELETE
|
||||
) {
|
||||
// Execute CDC deletions if there's already a record
|
||||
cdcDeleteClause =
|
||||
"WHEN MATCHED AND new_record.\"${CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()}\" IS NOT NULL AND $cursorComparison THEN DELETE"
|
||||
"WHEN MATCHED AND new_record.\"${SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN}\" IS NOT NULL AND $cursorComparison THEN DELETE"
|
||||
// And skip insertion entirely if there's no matching record.
|
||||
// (This is possible if a single T+D batch contains both an insertion and deletion for
|
||||
// the same PK)
|
||||
cdcSkipInsertClause =
|
||||
"AND new_record.\"${CDC_DELETED_AT_COLUMN.toSnowflakeCompatibleName()}\" IS NULL"
|
||||
cdcSkipInsertClause = "AND new_record.\"${SNOWFLAKE_AB_CDC_DELETED_AT_COLUMN}\" IS NULL"
|
||||
} else {
|
||||
cdcDeleteClause = ""
|
||||
cdcSkipInsertClause = ""
|
||||
@@ -203,35 +207,35 @@ class SnowflakeDirectLoadSqlGenerator(
|
||||
val mergeStatement =
|
||||
if (cdcDeleteClause.isNotEmpty()) {
|
||||
"""
|
||||
MERGE INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} AS target_table
|
||||
USING (
|
||||
$selectSourceRecords
|
||||
) AS new_record
|
||||
ON $pkEquivalent
|
||||
$cdcDeleteClause
|
||||
WHEN MATCHED AND $cursorComparison THEN UPDATE SET
|
||||
$columnAssignments
|
||||
WHEN NOT MATCHED $cdcSkipInsertClause THEN INSERT (
|
||||
$columnList
|
||||
) VALUES (
|
||||
$newRecordColumnList
|
||||
)
|
||||
""".trimIndent()
|
||||
|MERGE INTO ${fullyQualifiedName(targetTableName)} AS target_table
|
||||
|USING (
|
||||
|$selectSourceRecords
|
||||
|) AS new_record
|
||||
|ON $pkEquivalent
|
||||
|$cdcDeleteClause
|
||||
|WHEN MATCHED AND $cursorComparison THEN UPDATE SET
|
||||
| $columnAssignments
|
||||
|WHEN NOT MATCHED $cdcSkipInsertClause THEN INSERT (
|
||||
| $columnList
|
||||
|) VALUES (
|
||||
| $newRecordColumnList
|
||||
|)
|
||||
""".trimMargin()
|
||||
} else {
|
||||
"""
|
||||
MERGE INTO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)} AS target_table
|
||||
USING (
|
||||
$selectSourceRecords
|
||||
) AS new_record
|
||||
ON $pkEquivalent
|
||||
WHEN MATCHED AND $cursorComparison THEN UPDATE SET
|
||||
$columnAssignments
|
||||
WHEN NOT MATCHED THEN INSERT (
|
||||
$columnList
|
||||
) VALUES (
|
||||
$newRecordColumnList
|
||||
)
|
||||
""".trimIndent()
|
||||
|MERGE INTO ${fullyQualifiedName(targetTableName)} AS target_table
|
||||
|USING (
|
||||
|$selectSourceRecords
|
||||
|) AS new_record
|
||||
|ON $pkEquivalent
|
||||
|WHEN MATCHED AND $cursorComparison THEN UPDATE SET
|
||||
| $columnAssignments
|
||||
|WHEN NOT MATCHED THEN INSERT (
|
||||
| $columnList
|
||||
|) VALUES (
|
||||
| $newRecordColumnList
|
||||
|)
|
||||
""".trimMargin()
|
||||
}
|
||||
|
||||
return mergeStatement.andLog()
|
||||
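To make the branching above concrete, here is a hedged example of the hard-delete variant of the generated MERGE; table names, keys, the cursor comparison, and the CDC column name are hypothetical placeholders, not connector output.

val exampleMerge = """
    |MERGE INTO "AIRBYTE_DB"."PUBLIC"."USERS" AS target_table
    |USING (
    |  ... deduped SELECT over the source table ...
    |) AS new_record
    |ON (target_table."ID" = new_record."ID" OR (target_table."ID" IS NULL AND new_record."ID" IS NULL))
    |WHEN MATCHED AND new_record."_AB_CDC_DELETED_AT" IS NOT NULL AND <cursor comparison> THEN DELETE
    |WHEN MATCHED AND <cursor comparison> THEN UPDATE SET
    |  "NAME" = new_record."NAME", ...
    |WHEN NOT MATCHED AND new_record."_AB_CDC_DELETED_AT" IS NULL THEN INSERT (
    |  "_AIRBYTE_RAW_ID", ..., "ID", "NAME"
    |) VALUES (
    |  new_record."_AIRBYTE_RAW_ID", ..., new_record."ID", new_record."NAME"
    |)
    """.trimMargin()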
@@ -242,75 +246,66 @@ class SnowflakeDirectLoadSqlGenerator(
|
||||
* table. Uses ROW_NUMBER() window function to select the most recent record per primary key.
|
||||
*/
|
||||
private fun selectDedupedRecords(
|
||||
stream: DestinationStream,
|
||||
sourceTableName: TableName,
|
||||
columnNameMapping: ColumnNameMapping
|
||||
tableSchema: StreamTableSchema,
|
||||
sourceTableName: TableName
|
||||
): String {
|
||||
val columnList: String =
|
||||
columnUtils
|
||||
.getFormattedColumnNames(
|
||||
columns = stream.schema.asColumns(),
|
||||
columnNameMapping = columnNameMapping,
|
||||
quote = false,
|
||||
)
|
||||
.joinToString(
|
||||
",\n",
|
||||
) {
|
||||
it.quote()
|
||||
}
|
||||
val importType = stream.importType as Dedupe
|
||||
val allColumns = buildList {
|
||||
add(SNOWFLAKE_AB_RAW_ID)
|
||||
add(SNOWFLAKE_AB_EXTRACTED_AT)
|
||||
add(SNOWFLAKE_AB_META)
|
||||
add(SNOWFLAKE_AB_GENERATION_ID)
|
||||
addAll(tableSchema.columnSchema.finalSchema.keys)
|
||||
}
|
||||
val columnList: String = allColumns.joinToString(",\n ") { it.quote() }
|
||||
|
||||
// Build the primary key list for partitioning
|
||||
val pks = tableSchema.getPrimaryKey().flatten()
|
||||
val pkList =
|
||||
if (importType.primaryKey.isNotEmpty()) {
|
||||
importType.primaryKey.joinToString(",") { fieldPath ->
|
||||
(columnNameMapping[fieldPath.first()] ?: fieldPath.first()).quote()
|
||||
}
|
||||
if (pks.isNotEmpty()) {
|
||||
pks.joinToString(",") { it.quote() }
|
||||
} else {
|
||||
// Should not happen as we check this earlier, but handle it defensively
|
||||
throw IllegalArgumentException("Cannot deduplicate without primary key")
|
||||
}
|
||||
|
||||
// Build cursor order clause for sorting within each partition
|
||||
val cursor = tableSchema.getCursor().firstOrNull()
|
||||
val cursorOrderClause =
|
||||
if (importType.cursor.isNotEmpty()) {
|
||||
val columnName =
|
||||
(columnNameMapping[importType.cursor.first()] ?: importType.cursor.first())
|
||||
.quote()
|
||||
"$columnName DESC NULLS LAST,"
|
||||
if (cursor != null) {
|
||||
"${cursor.quote()} DESC NULLS LAST,"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
|
||||
return """
|
||||
WITH records AS (
|
||||
SELECT
|
||||
$columnList
|
||||
FROM ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)}
|
||||
), numbered_rows AS (
|
||||
SELECT *, ROW_NUMBER() OVER (
|
||||
PARTITION BY $pkList ORDER BY $cursorOrderClause "${COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName()}" DESC
|
||||
) AS row_number
|
||||
FROM records
|
||||
)
|
||||
SELECT $columnList
|
||||
FROM numbered_rows
|
||||
WHERE row_number = 1
|
||||
| WITH records AS (
|
||||
| SELECT
|
||||
| $columnList
|
||||
| FROM ${fullyQualifiedName(sourceTableName)}
|
||||
| ), numbered_rows AS (
|
||||
| SELECT *, ROW_NUMBER() OVER (
|
||||
| PARTITION BY $pkList ORDER BY $cursorOrderClause "$SNOWFLAKE_AB_EXTRACTED_AT" DESC
|
||||
| ) AS row_number
|
||||
| FROM records
|
||||
| )
|
||||
| SELECT $columnList
|
||||
| FROM numbered_rows
|
||||
| WHERE row_number = 1
|
||||
"""
|
||||
.trimIndent()
|
||||
.trimMargin()
|
||||
.andLog()
|
||||
}
|
||||
|
||||
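Not the connector's code, but an in-memory sketch of the semantics the ROW_NUMBER() query above implements: keep one row per primary key, preferring the highest cursor value (nulls losing) and breaking ties by the latest extraction timestamp.

data class RowSketch(val pk: String, val cursor: Long?, val extractedAt: Long)

fun dedupeSketch(rows: List<RowSketch>): List<RowSketch> =
    rows.groupBy { it.pk }
        .values
        .map { group ->
            group.maxWithOrNull(
                compareBy<RowSketch, Long?>(nullsFirst<Long>()) { it.cursor }
                    .thenBy { it.extractedAt }
            )!!
        }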
fun dropTable(tableName: TableName): String {
|
||||
return "DROP TABLE IF EXISTS ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}".andLog()
|
||||
return "DROP TABLE IF EXISTS ${fullyQualifiedName(tableName)}".andLog()
|
||||
}
|
||||
|
||||
fun getGenerationId(
|
||||
tableName: TableName,
|
||||
): String {
|
||||
return """
|
||||
SELECT "${columnUtils.getGenerationIdColumnName()}"
|
||||
FROM ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}
|
||||
SELECT "${columnManager.getGenerationIdColumnName()}"
|
||||
FROM ${fullyQualifiedName(tableName)}
|
||||
LIMIT 1
|
||||
"""
|
||||
.trimIndent()
|
||||
@@ -318,12 +313,12 @@ class SnowflakeDirectLoadSqlGenerator(
}

fun createSnowflakeStage(tableName: TableName): String {
val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName)
val stageName = fullyQualifiedStageName(tableName)
return "CREATE STAGE IF NOT EXISTS $stageName".andLog()
}

fun putInStage(tableName: TableName, tempFilePath: String): String {
val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true)
val stageName = fullyQualifiedStageName(tableName, true)
return """
PUT 'file://$tempFilePath' '@$stageName'
AUTO_COMPRESS = FALSE
@@ -334,35 +329,45 @@ class SnowflakeDirectLoadSqlGenerator(
|
||||
.andLog()
|
||||
}
|
||||
|
||||
fun copyFromStage(tableName: TableName, filename: String): String {
|
||||
val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true)
|
||||
fun copyFromStage(
|
||||
tableName: TableName,
|
||||
filename: String,
|
||||
columnNames: List<String>? = null
|
||||
): String {
|
||||
val stageName = fullyQualifiedStageName(tableName, true)
|
||||
val columnList =
|
||||
columnNames?.let { names -> "(${names.joinToString(", ") { it.quote() }})" } ?: ""
|
||||
|
||||
return """
|
||||
COPY INTO ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}
|
||||
FROM '@$stageName'
|
||||
FILE_FORMAT = (
|
||||
TYPE = 'CSV'
|
||||
COMPRESSION = GZIP
|
||||
FIELD_DELIMITER = '$CSV_FIELD_SEPARATOR'
|
||||
RECORD_DELIMITER = '$CSV_LINE_DELIMITER'
|
||||
FIELD_OPTIONALLY_ENCLOSED_BY = '"'
|
||||
TRIM_SPACE = TRUE
|
||||
ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE
|
||||
REPLACE_INVALID_CHARACTERS = TRUE
|
||||
ESCAPE = NONE
|
||||
ESCAPE_UNENCLOSED_FIELD = NONE
|
||||
)
|
||||
ON_ERROR = 'ABORT_STATEMENT'
|
||||
PURGE = TRUE
|
||||
files = ('$filename')
|
||||
|COPY INTO ${fullyQualifiedName(tableName)}$columnList
|
||||
|FROM '@$stageName'
|
||||
|FILE_FORMAT = (
|
||||
| TYPE = 'CSV'
|
||||
| COMPRESSION = GZIP
|
||||
| FIELD_DELIMITER = '$CSV_FIELD_SEPARATOR'
|
||||
| RECORD_DELIMITER = '$CSV_LINE_DELIMITER'
|
||||
| FIELD_OPTIONALLY_ENCLOSED_BY = '"'
|
||||
| TRIM_SPACE = TRUE
|
||||
| ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE
|
||||
| REPLACE_INVALID_CHARACTERS = TRUE
|
||||
| ESCAPE = NONE
|
||||
| ESCAPE_UNENCLOSED_FIELD = NONE
|
||||
|)
|
||||
|ON_ERROR = 'ABORT_STATEMENT'
|
||||
|PURGE = TRUE
|
||||
|files = ('$filename')
|
||||
"""
|
||||
.trimIndent()
|
||||
.trimMargin()
|
||||
.andLog()
|
||||
}
|
||||
|
||||
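A small sketch of the optional column-list handling above. Passing explicit target column names makes COPY INTO map CSV fields by position into exactly those columns, which stays correct even after ALTER TABLE has appended columns; the names below are examples.

fun columnListSuffix(columnNames: List<String>?): String =
    columnNames?.let { names -> "(${names.joinToString(", ") { "\"$it\"" }})" } ?: ""

// columnListSuffix(listOf("_AIRBYTE_RAW_ID", "NAME")) == "(\"_AIRBYTE_RAW_ID\", \"NAME\")"
// columnListSuffix(null) == ""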
fun swapTableWith(sourceTableName: TableName, targetTableName: TableName): String {
|
||||
return """
|
||||
ALTER TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} SWAP WITH ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)}
|
||||
ALTER TABLE ${fullyQualifiedName(sourceTableName)} SWAP WITH ${
|
||||
fullyQualifiedName(
|
||||
targetTableName,
|
||||
)
|
||||
}
|
||||
"""
|
||||
.trimIndent()
|
||||
.andLog()
|
||||
@@ -372,7 +377,11 @@ class SnowflakeDirectLoadSqlGenerator(
|
||||
// Snowflake RENAME TO only accepts the table name, not a fully qualified name
|
||||
// The renamed table stays in the same schema
|
||||
return """
|
||||
ALTER TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(sourceTableName)} RENAME TO ${snowflakeSqlNameUtils.fullyQualifiedName(targetTableName)}
|
||||
ALTER TABLE ${fullyQualifiedName(sourceTableName)} RENAME TO ${
|
||||
fullyQualifiedName(
|
||||
targetTableName,
|
||||
)
|
||||
}
|
||||
"""
|
||||
.trimIndent()
|
||||
.andLog()
|
||||
@@ -382,7 +391,7 @@ class SnowflakeDirectLoadSqlGenerator(
|
||||
schemaName: String,
|
||||
tableName: String,
|
||||
): String =
|
||||
"""DESCRIBE TABLE ${snowflakeSqlNameUtils.fullyQualifiedName(TableName(schemaName, tableName))}""".andLog()
|
||||
"""DESCRIBE TABLE ${fullyQualifiedName(TableName(schemaName, tableName))}""".andLog()
|
||||
|
||||
fun alterTable(
|
||||
tableName: TableName,
|
||||
@@ -391,14 +400,14 @@ class SnowflakeDirectLoadSqlGenerator(
|
||||
modifiedColumns: Map<String, ColumnTypeChange>,
|
||||
): Set<String> {
|
||||
val clauses = mutableSetOf<String>()
|
||||
val prettyTableName = snowflakeSqlNameUtils.fullyQualifiedName(tableName)
|
||||
val prettyTableName = fullyQualifiedName(tableName)
|
||||
addedColumns.forEach { (name, columnType) ->
|
||||
clauses.add(
|
||||
// Note that we intentionally don't set NOT NULL.
|
||||
// We're adding a new column, and we don't know what constitutes a reasonable
|
||||
// default value for preexisting records.
|
||||
// So we add the column as nullable.
|
||||
"ALTER TABLE $prettyTableName ADD COLUMN ${name.quote()} ${columnType.type};".andLog()
|
||||
"ALTER TABLE $prettyTableName ADD COLUMN ${name.quote()} ${columnType.type};".andLog(),
|
||||
)
|
||||
}
|
||||
deletedColumns.forEach {
|
||||
@@ -412,35 +421,34 @@ class SnowflakeDirectLoadSqlGenerator(
|
||||
val tempColumn = "${name}_${uuidGenerator.v4()}"
|
||||
clauses.add(
|
||||
// As above: we add the column as nullable.
|
||||
"ALTER TABLE $prettyTableName ADD COLUMN ${tempColumn.quote()} ${typeChange.newType.type};".andLog()
|
||||
"ALTER TABLE $prettyTableName ADD COLUMN ${tempColumn.quote()} ${typeChange.newType.type};".andLog(),
|
||||
)
|
||||
clauses.add(
|
||||
"UPDATE $prettyTableName SET ${tempColumn.quote()} = CAST(${name.quote()} AS ${typeChange.newType.type});".andLog()
|
||||
"UPDATE $prettyTableName SET ${tempColumn.quote()} = CAST(${name.quote()} AS ${typeChange.newType.type});".andLog(),
|
||||
)
|
||||
val backupColumn = "${tempColumn}_backup"
|
||||
clauses.add(
|
||||
"""
|
||||
ALTER TABLE $prettyTableName
|
||||
RENAME COLUMN "$name" TO "$backupColumn";
|
||||
""".trimIndent()
|
||||
""".trimIndent(),
|
||||
)
|
||||
clauses.add(
|
||||
"""
|
||||
ALTER TABLE $prettyTableName
|
||||
RENAME COLUMN "$tempColumn" TO "$name";
|
||||
""".trimIndent()
|
||||
""".trimIndent(),
|
||||
)
|
||||
clauses.add(
|
||||
"ALTER TABLE $prettyTableName DROP COLUMN ${backupColumn.quote()};".andLog()
|
||||
"ALTER TABLE $prettyTableName DROP COLUMN ${backupColumn.quote()};".andLog(),
|
||||
)
|
||||
} else if (!typeChange.originalType.nullable && typeChange.newType.nullable) {
|
||||
// If the type is unchanged, we can change a column from NOT NULL to nullable.
|
||||
// But we'll never do the reverse, because there's a decent chance that historical
|
||||
// records
|
||||
// had null values.
|
||||
// records had null values.
|
||||
// Users can always manually ALTER COLUMN ... SET NOT NULL if they want.
|
||||
clauses.add(
|
||||
"""ALTER TABLE $prettyTableName ALTER COLUMN "$name" DROP NOT NULL;""".andLog()
|
||||
"""ALTER TABLE $prettyTableName ALTER COLUMN "$name" DROP NOT NULL;""".andLog(),
|
||||
)
|
||||
} else {
|
||||
log.info {
|
||||
@@ -450,4 +458,45 @@ class SnowflakeDirectLoadSqlGenerator(
|
||||
}
|
||||
return clauses
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
fun fullyQualifiedName(tableName: TableName): String =
|
||||
combineParts(listOf(getDatabaseName(), tableName.namespace, tableName.name))
|
||||
|
||||
@VisibleForTesting
|
||||
fun fullyQualifiedNamespace(namespace: String) =
|
||||
combineParts(listOf(getDatabaseName(), namespace))
|
||||
|
||||
@VisibleForTesting
|
||||
fun fullyQualifiedStageName(tableName: TableName, escape: Boolean = false): String {
|
||||
val currentTableName =
|
||||
if (escape) {
|
||||
tableName.name
|
||||
} else {
|
||||
tableName.name
|
||||
}
|
||||
return combineParts(
|
||||
parts =
|
||||
listOf(
|
||||
getDatabaseName(),
|
||||
tableName.namespace,
|
||||
"$STAGE_NAME_PREFIX$currentTableName",
|
||||
),
|
||||
escape = escape,
|
||||
)
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
internal fun combineParts(parts: List<String>, escape: Boolean = false): String =
|
||||
parts
|
||||
.map { if (escape) sqlEscape(it) else it }
|
||||
.joinToString(separator = ".") {
|
||||
if (!it.startsWith(QUOTE)) {
|
||||
"$QUOTE$it$QUOTE"
|
||||
} else {
|
||||
it
|
||||
}
|
||||
}
|
||||
|
||||
private fun getDatabaseName() = config.database.toSnowflakeCompatibleName()
|
||||
}
|
||||
|
||||
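Standalone sketch of the quoting rule combineParts() applies (identifiers are examples; the real method also optionally SQL-escapes each part):

private const val DQ = "\""

fun combineSketch(parts: List<String>): String =
    parts.joinToString(".") { if (it.startsWith(DQ)) it else "$DQ$it$DQ" }

// combineSketch(listOf("AIRBYTE_DB", "PUBLIC", "USERS"))
//   == "\"AIRBYTE_DB\".\"PUBLIC\".\"USERS\""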
@@ -0,0 +1,29 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.integrations.destination.snowflake.sql

const val STAGE_NAME_PREFIX = "airbyte_stage_"
internal const val QUOTE: String = "\""

fun sqlEscape(part: String) = part.replace("\\", "\\\\").replace("'", "\\'").replace("\"", "\\\"")

/**
* Surrounds the string instance with double quotation marks (e.g. "some string" -> "\"some
* string\"").
*/
fun String.quote() = "$QUOTE$this$QUOTE"

/**
* Escapes double-quotes in a JSON identifier by doubling them. This is legacy -- I don't know why
* this would be necessary but no harm in keeping it, so I am keeping it.
*
* @return The escaped identifier.
*/
fun escapeJsonIdentifier(identifier: String): String {
// Note that we don't need to escape backslashes here!
// The only special character in an identifier is the double-quote, which needs to be
// doubled.
return identifier.replace(QUOTE, "$QUOTE$QUOTE")
}
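Quick usage illustration of the helpers in the new file above, assuming they are in scope; the outputs shown in comments are what the string replacements yield:

fun main() {
    println(escapeJsonIdentifier("""my "quoted" column"""))  // my ""quoted"" column
    println(sqlEscape("""path\to'file"""))                   // path\\to\'file
    println("users".quote())                                 // "users"
}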
@@ -1,57 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination.snowflake.sql
|
||||
|
||||
import io.airbyte.cdk.load.table.TableName
|
||||
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
import jakarta.inject.Singleton
|
||||
|
||||
const val STAGE_NAME_PREFIX = "airbyte_stage_"
|
||||
internal const val QUOTE: String = "\""
|
||||
|
||||
fun sqlEscape(part: String) = part.replace("\\", "\\\\").replace("'", "\\'").replace("\"", "\\\"")
|
||||
|
||||
@Singleton
|
||||
class SnowflakeSqlNameUtils(
|
||||
private val snowflakeConfiguration: SnowflakeConfiguration,
|
||||
) {
|
||||
fun fullyQualifiedName(tableName: TableName): String =
|
||||
combineParts(listOf(getDatabaseName(), tableName.namespace, tableName.name))
|
||||
|
||||
fun fullyQualifiedNamespace(namespace: String) =
|
||||
combineParts(listOf(getDatabaseName(), namespace))
|
||||
|
||||
fun fullyQualifiedStageName(tableName: TableName, escape: Boolean = false): String {
|
||||
val currentTableName =
|
||||
if (escape) {
|
||||
tableName.name
|
||||
} else {
|
||||
tableName.name
|
||||
}
|
||||
return combineParts(
|
||||
parts =
|
||||
listOf(
|
||||
getDatabaseName(),
|
||||
tableName.namespace,
|
||||
"$STAGE_NAME_PREFIX$currentTableName"
|
||||
),
|
||||
escape = escape,
|
||||
)
|
||||
}
|
||||
|
||||
fun combineParts(parts: List<String>, escape: Boolean = false): String =
|
||||
parts
|
||||
.map { if (escape) sqlEscape(it) else it }
|
||||
.joinToString(separator = ".") {
|
||||
if (!it.startsWith(QUOTE)) {
|
||||
"$QUOTE$it$QUOTE"
|
||||
} else {
|
||||
it
|
||||
}
|
||||
}
|
||||
|
||||
private fun getDatabaseName() = snowflakeConfiguration.database.toSnowflakeCompatibleName()
|
||||
}
|
||||
@@ -6,38 +6,39 @@ package io.airbyte.integrations.destination.snowflake.write
|
||||
|
||||
import io.airbyte.cdk.SystemErrorException
|
||||
import io.airbyte.cdk.load.command.Dedupe
|
||||
import io.airbyte.cdk.load.command.DestinationCatalog
|
||||
import io.airbyte.cdk.load.command.DestinationStream
|
||||
import io.airbyte.cdk.load.orchestration.db.DatabaseInitialStatusGatherer
|
||||
import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
|
||||
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadInitialStatus
|
||||
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendStreamLoader
|
||||
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendTruncateStreamLoader
|
||||
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableDedupStreamLoader
|
||||
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableDedupTruncateStreamLoader
|
||||
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig
|
||||
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
|
||||
import io.airbyte.cdk.load.table.ColumnNameMapping
|
||||
import io.airbyte.cdk.load.table.DatabaseInitialStatusGatherer
|
||||
import io.airbyte.cdk.load.table.TempTableNameGenerator
|
||||
import io.airbyte.cdk.load.table.directload.DirectLoadInitialStatus
|
||||
import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendStreamLoader
|
||||
import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendTruncateStreamLoader
|
||||
import io.airbyte.cdk.load.table.directload.DirectLoadTableDedupStreamLoader
|
||||
import io.airbyte.cdk.load.table.directload.DirectLoadTableDedupTruncateStreamLoader
|
||||
import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
|
||||
import io.airbyte.cdk.load.write.DestinationWriter
|
||||
import io.airbyte.cdk.load.write.StreamLoader
|
||||
import io.airbyte.cdk.load.write.StreamStateStore
|
||||
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
|
||||
import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
|
||||
import jakarta.inject.Singleton
|
||||
|
||||
@Singleton
|
||||
class SnowflakeWriter(
|
||||
private val names: TableCatalog,
|
||||
private val catalog: DestinationCatalog,
|
||||
private val stateGatherer: DatabaseInitialStatusGatherer<DirectLoadInitialStatus>,
|
||||
private val streamStateStore: StreamStateStore<DirectLoadTableExecutionConfig>,
|
||||
private val snowflakeClient: SnowflakeAirbyteClient,
|
||||
private val tempTableNameGenerator: TempTableNameGenerator,
|
||||
private val snowflakeConfiguration: SnowflakeConfiguration,
|
||||
private val tempTableNameGenerator: TempTableNameGenerator,
|
||||
) : DestinationWriter {
|
||||
private lateinit var initialStatuses: Map<DestinationStream, DirectLoadInitialStatus>
|
||||
|
||||
override suspend fun setup() {
|
||||
names.values
|
||||
.map { (tableNames, _) -> tableNames.finalTableName!!.namespace }
|
||||
catalog.streams
|
||||
.map { it.tableSchema.tableNames.finalTableName!!.namespace }
|
||||
.toSet()
|
||||
.forEach { snowflakeClient.createNamespace(it) }
|
||||
|
||||
@@ -45,15 +46,15 @@ class SnowflakeWriter(
|
||||
escapeJsonIdentifier(snowflakeConfiguration.internalTableSchema)
|
||||
)
|
||||
|
||||
initialStatuses = stateGatherer.gatherInitialStatus(names)
|
||||
initialStatuses = stateGatherer.gatherInitialStatus()
|
||||
}
|
||||
|
||||
override fun createStreamLoader(stream: DestinationStream): StreamLoader {
|
||||
val initialStatus = initialStatuses[stream]!!
|
||||
val tableNameInfo = names[stream]!!
|
||||
val realTableName = tableNameInfo.tableNames.finalTableName!!
|
||||
val tempTableName = tempTableNameGenerator.generate(realTableName)
|
||||
val columnNameMapping = tableNameInfo.columnNameMapping
|
||||
val realTableName = stream.tableSchema.tableNames.finalTableName!!
|
||||
val tempTableName = stream.tableSchema.tableNames.tempTableName!!
|
||||
val columnNameMapping =
|
||||
ColumnNameMapping(stream.tableSchema.columnSchema.inputToFinalColumnNames)
|
||||
return when (stream.minimumGenerationId) {
|
||||
0L ->
|
||||
when (stream.importType) {
|
||||
|
||||
@@ -9,11 +9,12 @@ import de.siegmar.fastcsv.writer.CsvWriter
|
||||
import de.siegmar.fastcsv.writer.LineDelimiter
|
||||
import de.siegmar.fastcsv.writer.QuoteStrategies
|
||||
import io.airbyte.cdk.load.data.AirbyteValue
|
||||
import io.airbyte.cdk.load.table.TableName
|
||||
import io.airbyte.cdk.load.schema.model.ColumnSchema
|
||||
import io.airbyte.cdk.load.schema.model.TableName
|
||||
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
|
||||
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
import io.airbyte.integrations.destination.snowflake.sql.QUOTE
|
||||
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
|
||||
import io.github.oshai.kotlinlogging.KotlinLogging
|
||||
import java.io.File
|
||||
import java.io.OutputStream
|
||||
@@ -36,10 +37,11 @@ private const val CSV_WRITER_BUFFER_SIZE = 1024 * 1024 // 1 MB
|
||||
|
||||
class SnowflakeInsertBuffer(
|
||||
private val tableName: TableName,
|
||||
val columns: LinkedHashMap<String, String>,
|
||||
private val snowflakeClient: SnowflakeAirbyteClient,
|
||||
val snowflakeConfiguration: SnowflakeConfiguration,
|
||||
val snowflakeColumnUtils: SnowflakeColumnUtils,
|
||||
val columnSchema: ColumnSchema,
|
||||
private val columnManager: SnowflakeColumnManager,
|
||||
private val snowflakeRecordFormatter: SnowflakeRecordFormatter,
|
||||
private val flushLimit: Int = DEFAULT_FLUSH_LIMIT,
|
||||
) {
|
||||
|
||||
@@ -57,12 +59,6 @@ class SnowflakeInsertBuffer(
|
||||
.lineDelimiter(CSV_LINE_DELIMITER)
|
||||
.quoteStrategy(QuoteStrategies.REQUIRED)
|
||||
|
||||
private val snowflakeRecordFormatter: SnowflakeRecordFormatter =
|
||||
when (snowflakeConfiguration.legacyRawTablesOnly) {
|
||||
true -> SnowflakeRawRecordFormatter(columns, snowflakeColumnUtils)
|
||||
else -> SnowflakeSchemaRecordFormatter(columns, snowflakeColumnUtils)
|
||||
}
|
||||
|
||||
fun accumulate(recordFields: Map<String, AirbyteValue>) {
|
||||
if (csvFilePath == null) {
|
||||
val csvFile = createCsvFile()
|
||||
@@ -92,7 +88,9 @@ class SnowflakeInsertBuffer(
|
||||
"Copying staging data into ${tableName.toPrettyString(quote = QUOTE)}..."
|
||||
}
|
||||
// Finally, copy the data from the staging table to the final table
|
||||
snowflakeClient.copyFromStage(tableName, filePath.fileName.toString())
|
||||
// Pass column names to ensure correct mapping even after ALTER TABLE operations
|
||||
val columnNames = columnManager.getTableColumnNames(columnSchema)
|
||||
snowflakeClient.copyFromStage(tableName, filePath.fileName.toString(), columnNames)
|
||||
logger.info {
|
||||
"Finished insert of $recordCount row(s) into ${tableName.toPrettyString(quote = QUOTE)}."
|
||||
}
|
||||
@@ -117,7 +115,9 @@ class SnowflakeInsertBuffer(
|
||||
|
||||
private fun writeToCsvFile(record: Map<String, AirbyteValue>) {
|
||||
csvWriter?.let {
|
||||
it.writeRecord(snowflakeRecordFormatter.format(record).map { col -> col.toString() })
|
||||
it.writeRecord(
|
||||
snowflakeRecordFormatter.format(record, columnSchema).map { col -> col.toString() }
|
||||
)
|
||||
recordCount++
|
||||
if ((recordCount % flushLimit) == 0) {
|
||||
it.flush()
|
||||
|
||||
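Sketch of the accumulate/flush cadence used by the buffer above (assumption: the flush limit counts accumulated records between CSV writer flushes):

class FlushCadenceSketch(private val flushLimit: Int, private val onFlush: () -> Unit) {
    private var recordCount = 0

    fun accumulate() {
        recordCount++
        if (recordCount % flushLimit == 0) onFlush()
    }
}

// FlushCadenceSketch(flushLimit = 1000) { /* csvWriter.flush() */ }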
@@ -8,99 +8,62 @@ import io.airbyte.cdk.load.data.AirbyteValue
|
||||
import io.airbyte.cdk.load.data.NullValue
|
||||
import io.airbyte.cdk.load.data.StringValue
|
||||
import io.airbyte.cdk.load.data.csv.toCsvValue
|
||||
import io.airbyte.cdk.load.message.Meta
|
||||
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
|
||||
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
|
||||
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT
|
||||
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
|
||||
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
|
||||
import io.airbyte.cdk.load.schema.model.ColumnSchema
|
||||
import io.airbyte.cdk.load.util.Jsons
|
||||
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
|
||||
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
|
||||
|
||||
interface SnowflakeRecordFormatter {
|
||||
fun format(record: Map<String, AirbyteValue>): List<Any>
|
||||
fun format(record: Map<String, AirbyteValue>, columnSchema: ColumnSchema): List<Any>
|
||||
}
|
||||
|
||||
class SnowflakeSchemaRecordFormatter(
|
||||
private val columns: LinkedHashMap<String, String>,
|
||||
val snowflakeColumnUtils: SnowflakeColumnUtils,
|
||||
) : SnowflakeRecordFormatter {
|
||||
class SnowflakeSchemaRecordFormatter : SnowflakeRecordFormatter {
|
||||
override fun format(record: Map<String, AirbyteValue>, columnSchema: ColumnSchema): List<Any> {
|
||||
val result = mutableListOf<Any>()
|
||||
val userColumns = columnSchema.finalSchema.keys
|
||||
|
||||
private val airbyteColumnNames =
|
||||
snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet()
|
||||
// WARNING: MUST match the order defined in SnowflakeColumnManager#getTableColumnNames
|
||||
//
|
||||
// Why don't we just use that here? Well, unlike the user fields, the meta fields on the
|
||||
// record are not munged for the destination. So we must access the values for those columns
|
||||
// using the original lowercase meta key.
|
||||
result.add(record[COLUMN_NAME_AB_RAW_ID].toCsvValue())
|
||||
result.add(record[COLUMN_NAME_AB_EXTRACTED_AT].toCsvValue())
|
||||
result.add(record[COLUMN_NAME_AB_META].toCsvValue())
|
||||
result.add(record[COLUMN_NAME_AB_GENERATION_ID].toCsvValue())
|
||||
|
||||
override fun format(record: Map<String, AirbyteValue>): List<Any> =
|
||||
columns.map { (columnName, _) ->
|
||||
/*
|
||||
* Meta columns are forced to uppercase for backwards compatibility with previous
|
||||
* versions of the destination. Therefore, convert the column to lowercase so
|
||||
* that it can match the constants, which use the lowercase version of the meta
|
||||
* column names.
|
||||
*/
|
||||
if (airbyteColumnNames.contains(columnName)) {
|
||||
record[columnName.lowercase()].toCsvValue()
|
||||
} else {
|
||||
record.keys
|
||||
// The columns retrieved from Snowflake do not have any escaping applied.
|
||||
// Therefore, re-apply the compatible name escaping to the name of the
|
||||
// columns retrieved from Snowflake. The record keys should already have
|
||||
// been escaped by the CDK before arriving at the aggregate, so no need
|
||||
// to escape again here.
|
||||
.find { it == columnName.toSnowflakeCompatibleName() }
|
||||
?.let { record[it].toCsvValue() }
|
||||
?: ""
|
||||
}
|
||||
}
|
||||
// Add user columns from the final schema
|
||||
userColumns.forEach { columnName -> result.add(record[columnName].toCsvValue()) }
|
||||
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
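A compact model of the row layout the schema formatter above emits: the four Airbyte meta values first, in the fixed order the comment ties to SnowflakeColumnManager#getTableColumnNames, followed by the user columns in final-schema order (meta names here are the standard Airbyte constants):

fun schemaRowLayout(userColumns: List<String>): List<String> =
    listOf(
        "_airbyte_raw_id",
        "_airbyte_extracted_at",
        "_airbyte_meta",
        "_airbyte_generation_id",
    ) + userColumns

// schemaRowLayout(listOf("ID", "NAME"))
//   -> [_airbyte_raw_id, _airbyte_extracted_at, _airbyte_meta, _airbyte_generation_id, ID, NAME]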
class SnowflakeRawRecordFormatter(
|
||||
columns: LinkedHashMap<String, String>,
|
||||
val snowflakeColumnUtils: SnowflakeColumnUtils,
|
||||
) : SnowflakeRecordFormatter {
|
||||
private val columns = columns.keys
|
||||
class SnowflakeRawRecordFormatter : SnowflakeRecordFormatter {
|
||||
|
||||
private val airbyteColumnNames =
|
||||
snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet()
|
||||
|
||||
override fun format(record: Map<String, AirbyteValue>): List<Any> =
|
||||
override fun format(record: Map<String, AirbyteValue>, columnSchema: ColumnSchema): List<Any> =
|
||||
toOutputRecord(record.toMutableMap())
|
||||
|
||||
private fun toOutputRecord(record: MutableMap<String, AirbyteValue>): List<Any> {
|
||||
val outputRecord = mutableListOf<Any>()
|
||||
// Copy the Airbyte metadata columns to the raw output, removing each
|
||||
// one from the record to avoid duplicates in the "data" field
|
||||
columns
|
||||
.filter { airbyteColumnNames.contains(it) && it != Meta.COLUMN_NAME_DATA }
|
||||
.forEach { column -> safeAddToOutput(column, record, outputRecord) }
|
||||
val mutableRecord = record.toMutableMap()
|
||||
|
||||
// Add meta columns in order (except _airbyte_data which we handle specially)
|
||||
outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_RAW_ID)?.toCsvValue() ?: "")
|
||||
outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_EXTRACTED_AT)?.toCsvValue() ?: "")
|
||||
outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_META)?.toCsvValue() ?: "")
|
||||
outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_GENERATION_ID)?.toCsvValue() ?: "")
|
||||
outputRecord.add(mutableRecord.remove(COLUMN_NAME_AB_LOADED_AT)?.toCsvValue() ?: "")
|
||||
|
||||
// Do not output null values in the JSON raw output
|
||||
val filteredRecord = record.filter { (_, v) -> v !is NullValue }
|
||||
// Convert all the remaining columns in the record to a JSON document stored in the "data"
|
||||
// column. Add it in the same position as the _airbyte_data column in the column list to
|
||||
// ensure it is inserted into the proper column in the table.
|
||||
insert(
|
||||
columns.indexOf(Meta.COLUMN_NAME_DATA),
|
||||
StringValue(Jsons.writeValueAsString(filteredRecord)).toCsvValue(),
|
||||
outputRecord
|
||||
)
|
||||
val filteredRecord = mutableRecord.filter { (_, v) -> v !is NullValue }
|
||||
|
||||
// Convert all the remaining columns to a JSON document stored in the "data" column
|
||||
outputRecord.add(StringValue(Jsons.writeValueAsString(filteredRecord)).toCsvValue())
|
||||
|
||||
return outputRecord
|
||||
}
|
||||
|
||||
private fun safeAddToOutput(
|
||||
key: String,
|
||||
record: MutableMap<String, AirbyteValue>,
|
||||
output: MutableList<Any>
|
||||
) {
|
||||
val extractedValue = record.remove(key)
|
||||
// Ensure that the data is inserted into the list at the same position as the column
|
||||
insert(columns.indexOf(key), extractedValue?.toCsvValue() ?: "", output)
|
||||
}
|
||||
|
||||
private fun insert(index: Int, value: Any, list: MutableList<Any>) {
|
||||
/*
|
||||
* Attempt to insert the value into the proper order in the list. If the index
|
||||
* is already present in the list, use the add(index, element) method to insert it
|
||||
* into the proper order and push everything to the right. If the index is at the
|
||||
* end of the list, just use add(element) to insert it at the end. If the index
|
||||
* is further beyond the end of the list, throw an exception as that should not occur.
|
||||
*/
|
||||
if (index < list.size) list.add(index, value)
|
||||
else if (index == list.size || index == list.size + 1) list.add(value)
|
||||
else throw IndexOutOfBoundsException()
|
||||
}
|
||||
}
|
||||
|
||||
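Not connector code: a minimal model of the raw-mode row built above, where the meta values come first (including _airbyte_loaded_at) and every remaining non-null field is folded into one serialized "data" cell; serialization is stubbed here.

fun rawRowSketch(record: Map<String, Any?>, metaColumns: List<String>): List<Any> {
    val remaining = record.toMutableMap()
    val row: MutableList<Any> = metaColumns.map { remaining.remove(it) ?: "" }.toMutableList()
    val nonNull = remaining.filterValues { it != null }
    row.add(nonNull.toString()) // stand-in for Jsons.writeValueAsString(nonNull)
    return row
}

// rawRowSketch(
//     mapOf("_airbyte_raw_id" to "r1", "name" to "Ada", "nickname" to null),
//     listOf("_airbyte_raw_id", "_airbyte_extracted_at", "_airbyte_meta",
//            "_airbyte_generation_id", "_airbyte_loaded_at"),
// ) -> ["r1", "", "", "", "", "{name=Ada}"]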
@@ -1,25 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.integrations.destination.snowflake.write.transform

import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.dataflow.transform.ColumnNameMapper
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import jakarta.inject.Singleton

@Singleton
class SnowflakeColumnNameMapper(
private val catalogInfo: TableCatalog,
private val snowflakeConfiguration: SnowflakeConfiguration,
) : ColumnNameMapper {
override fun getMappedColumnName(stream: DestinationStream, columnName: String): String {
if (snowflakeConfiguration.legacyRawTablesOnly == true) {
return columnName
} else {
return catalogInfo.getMappedColumnName(stream, columnName)!!
}
}
}
@@ -7,9 +7,11 @@ package io.airbyte.integrations.destination.snowflake.component
|
||||
import io.airbyte.cdk.load.component.TableOperationsFixtures
|
||||
import io.airbyte.cdk.load.component.TableOperationsSuite
|
||||
import io.airbyte.cdk.load.message.Meta
|
||||
import io.airbyte.cdk.load.schema.TableSchemaFactory
|
||||
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
|
||||
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.idTestWithCdcMapping
|
||||
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.testMapping
|
||||
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.idTestWithCdcMapping
|
||||
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.testMapping
|
||||
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeTestTableOperationsClient
|
||||
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
|
||||
import org.junit.jupiter.api.Test
|
||||
import org.junit.jupiter.api.parallel.Execution
|
||||
@@ -20,6 +22,7 @@ import org.junit.jupiter.api.parallel.ExecutionMode
|
||||
class SnowflakeTableOperationsTest(
|
||||
override val client: SnowflakeAirbyteClient,
|
||||
override val testClient: SnowflakeTestTableOperationsClient,
|
||||
override val schemaFactory: TableSchemaFactory,
|
||||
) : TableOperationsSuite {
|
||||
override val airbyteMetaColumnMapping = Meta.COLUMN_NAMES.associateWith { it.uppercase() }
|
||||
|
||||
|
||||
@@ -9,13 +9,15 @@ import io.airbyte.cdk.load.component.TableSchemaEvolutionFixtures
|
||||
import io.airbyte.cdk.load.component.TableSchemaEvolutionSuite
|
||||
import io.airbyte.cdk.load.data.StringValue
|
||||
import io.airbyte.cdk.load.message.Meta
|
||||
import io.airbyte.cdk.load.schema.TableSchemaFactory
|
||||
import io.airbyte.cdk.load.table.ColumnNameMapping
|
||||
import io.airbyte.cdk.load.util.serializeToString
|
||||
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
|
||||
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.allTypesColumnNameMapping
|
||||
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.allTypesTableSchema
|
||||
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.idAndTestMapping
|
||||
import io.airbyte.integrations.destination.snowflake.component.SnowflakeComponentTestFixtures.testMapping
|
||||
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.allTypesColumnNameMapping
|
||||
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.allTypesTableSchema
|
||||
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.idAndTestMapping
|
||||
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeComponentTestFixtures.testMapping
|
||||
import io.airbyte.integrations.destination.snowflake.component.config.SnowflakeTestTableOperationsClient
|
||||
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
|
||||
import org.junit.jupiter.api.Test
|
||||
import org.junit.jupiter.api.parallel.Execution
|
||||
@@ -27,6 +29,7 @@ class SnowflakeTableSchemaEvolutionTest(
|
||||
override val client: SnowflakeAirbyteClient,
|
||||
override val opsClient: SnowflakeAirbyteClient,
|
||||
override val testClient: SnowflakeTestTableOperationsClient,
|
||||
override val schemaFactory: TableSchemaFactory,
|
||||
) : TableSchemaEvolutionSuite {
|
||||
override val airbyteMetaColumnMapping = Meta.COLUMN_NAMES.associateWith { it.uppercase() }
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination.snowflake.component
|
||||
package io.airbyte.integrations.destination.snowflake.component.config
|
||||
|
||||
import io.airbyte.cdk.load.component.ColumnType
|
||||
import io.airbyte.cdk.load.component.TableOperationsFixtures
|
||||
@@ -33,6 +33,8 @@ object SnowflakeComponentTestFixtures {
|
||||
"TIME_NTZ" to ColumnType("TIME", true),
|
||||
"ARRAY" to ColumnType("ARRAY", true),
|
||||
"OBJECT" to ColumnType("OBJECT", true),
|
||||
"UNION" to ColumnType("VARIANT", true),
|
||||
"LEGACY_UNION" to ColumnType("VARIANT", true),
|
||||
"UNKNOWN" to ColumnType("VARIANT", true),
|
||||
)
|
||||
)
|
||||
@@ -2,7 +2,7 @@
|
||||
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination.snowflake.component
|
||||
package io.airbyte.integrations.destination.snowflake.component.config
|
||||
|
||||
import io.airbyte.cdk.load.component.config.TestConfigLoader.loadTestConfig
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
@@ -2,22 +2,24 @@
|
||||
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination.snowflake.component
|
||||
package io.airbyte.integrations.destination.snowflake.component.config
|
||||
|
||||
import io.airbyte.cdk.load.component.ColumnType
|
||||
import io.airbyte.cdk.load.component.TestTableOperationsClient
|
||||
import io.airbyte.cdk.load.data.AirbyteValue
|
||||
import io.airbyte.cdk.load.dataflow.state.PartitionKey
|
||||
import io.airbyte.cdk.load.dataflow.transform.RecordDTO
|
||||
import io.airbyte.cdk.load.table.TableName
|
||||
import io.airbyte.cdk.load.schema.model.TableName
|
||||
import io.airbyte.cdk.load.util.Jsons
|
||||
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
|
||||
import io.airbyte.integrations.destination.snowflake.client.execute
|
||||
import io.airbyte.integrations.destination.snowflake.dataflow.SnowflakeAggregate
|
||||
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
|
||||
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils
|
||||
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
|
||||
import io.airbyte.integrations.destination.snowflake.sql.andLog
|
||||
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
|
||||
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
|
||||
import io.micronaut.context.annotation.Requires
|
||||
import jakarta.inject.Singleton
|
||||
import java.time.format.DateTimeFormatter
|
||||
@@ -29,25 +31,40 @@ import net.snowflake.client.jdbc.SnowflakeTimestampWithTimezone
|
||||
class SnowflakeTestTableOperationsClient(
|
||||
private val client: SnowflakeAirbyteClient,
|
||||
private val dataSource: DataSource,
|
||||
private val snowflakeSqlNameUtils: SnowflakeSqlNameUtils,
|
||||
private val snowflakeColumnUtils: SnowflakeColumnUtils,
|
||||
private val sqlGenerator: SnowflakeDirectLoadSqlGenerator,
|
||||
private val snowflakeConfiguration: SnowflakeConfiguration,
|
||||
private val columnManager: SnowflakeColumnManager,
|
||||
private val snowflakeRecordFormatter: SnowflakeRecordFormatter,
|
||||
) : TestTableOperationsClient {
|
||||
override suspend fun dropNamespace(namespace: String) {
|
||||
dataSource.execute(
|
||||
"DROP SCHEMA IF EXISTS ${snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)}".andLog()
|
||||
"DROP SCHEMA IF EXISTS ${sqlGenerator.fullyQualifiedNamespace(namespace)}".andLog()
|
||||
)
|
||||
}
|
||||
|
||||
override suspend fun insertRecords(table: TableName, records: List<Map<String, AirbyteValue>>) {
|
||||
// TODO: we should just pass a proper column schema
|
||||
// Since we don't pass in a proper column schema, we have to recreate one here
|
||||
// Fetch the columns and filter out the meta columns so we're just looking at user columns
|
||||
val columnTypes =
|
||||
client.describeTable(table).filterNot {
|
||||
columnManager.getMetaColumnNames().contains(it.key)
|
||||
}
|
||||
val columnSchema =
|
||||
io.airbyte.cdk.load.schema.model.ColumnSchema(
|
||||
inputToFinalColumnNames = columnTypes.keys.associateWith { it },
|
||||
finalSchema = columnTypes.mapValues { (_, _) -> ColumnType("", true) },
|
||||
inputSchema = emptyMap() // Not needed for insert buffer
|
||||
)
|
||||
val a =
|
||||
SnowflakeAggregate(
|
||||
SnowflakeInsertBuffer(
|
||||
table,
|
||||
client.describeTable(table),
|
||||
client,
|
||||
snowflakeConfiguration,
|
||||
snowflakeColumnUtils,
|
||||
tableName = table,
|
||||
snowflakeClient = client,
|
||||
snowflakeConfiguration = snowflakeConfiguration,
|
||||
columnSchema = columnSchema,
|
||||
columnManager = columnManager,
|
||||
snowflakeRecordFormatter = snowflakeRecordFormatter,
|
||||
)
|
||||
)
|
||||
records.forEach { a.accept(RecordDTO(it, PartitionKey(""), 0, 0)) }
|
||||
@@ -7,11 +7,11 @@ package io.airbyte.integrations.destination.snowflake.write
import io.airbyte.cdk.load.test.util.DestinationCleaner
import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory
import io.airbyte.integrations.destination.snowflake.cdk.SnowflakeMigratingConfigurationSpecificationSupplier
import io.airbyte.integrations.destination.snowflake.db.escapeJsonIdentifier
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfigurationFactory
import io.airbyte.integrations.destination.snowflake.sql.STAGE_NAME_PREFIX
import io.airbyte.integrations.destination.snowflake.sql.escapeJsonIdentifier
import io.airbyte.integrations.destination.snowflake.sql.quote
import java.nio.file.Files
import java.sql.Connection
@@ -12,16 +12,22 @@ import io.airbyte.cdk.load.data.ObjectValue
import io.airbyte.cdk.load.data.json.toAirbyteValue
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN
import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator
import io.airbyte.cdk.load.test.util.DestinationDataDumper
import io.airbyte.cdk.load.test.util.OutputRecord
import io.airbyte.cdk.load.util.UUIDGenerator
import io.airbyte.cdk.load.util.deserializeToNode
import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory
import io.airbyte.integrations.destination.snowflake.db.SnowflakeFinalTableNameGenerator
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeTableSchemaMapper
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
import io.airbyte.integrations.destination.snowflake.sql.sqlEscape
import java.math.BigDecimal
import java.sql.Date
import java.sql.Time
import java.sql.Timestamp
import net.snowflake.client.jdbc.SnowflakeTimestampWithTimezone

private val AIRBYTE_META_COLUMNS = Meta.COLUMN_NAMES + setOf(CDC_DELETED_AT_COLUMN)
@@ -34,8 +40,14 @@ class SnowflakeDataDumper(
stream: DestinationStream
): List<OutputRecord> {
val config = configProvider(spec)
val sqlUtils = SnowflakeSqlNameUtils(config)
val snowflakeFinalTableNameGenerator = SnowflakeFinalTableNameGenerator(config)
val snowflakeFinalTableNameGenerator =
SnowflakeTableSchemaMapper(
config = config,
tempTableNameGenerator = DefaultTempTableNameGenerator(),
)
val snowflakeColumnManager = SnowflakeColumnManager(config)
val sqlGenerator =
SnowflakeDirectLoadSqlGenerator(UUIDGenerator(), config, snowflakeColumnManager)
val dataSource =
SnowflakeBeanFactory()
.snowflakeDataSource(snowflakeConfiguration = config, airbyteEdition = "COMMUNITY")
@@ -46,7 +58,7 @@ class SnowflakeDataDumper(
ds.connection.use { connection ->
val statement = connection.createStatement()
val tableName =
snowflakeFinalTableNameGenerator.getTableName(stream.mappedDescriptor)
snowflakeFinalTableNameGenerator.toFinalTableName(stream.mappedDescriptor)

// First check if the table exists
val tableExistsQuery =
@@ -69,7 +81,7 @@ class SnowflakeDataDumper(

val resultSet =
statement.executeQuery(
"SELECT * FROM ${sqlUtils.fullyQualifiedName(tableName)}"
"SELECT * FROM ${sqlGenerator.fullyQualifiedName(tableName)}"
)

while (resultSet.next()) {
@@ -143,10 +155,10 @@ class SnowflakeDataDumper(
private fun convertValue(value: Any?): Any? =
when (value) {
is BigDecimal -> value.toBigInteger()
is java.sql.Date -> value.toLocalDate()
is Date -> value.toLocalDate()
is SnowflakeTimestampWithTimezone -> value.toZonedDateTime()
is java.sql.Time -> value.toLocalTime()
is java.sql.Timestamp -> value.toLocalDateTime()
is Time -> value.toLocalTime()
is Timestamp -> value.toLocalDateTime()
else -> value
}
}
@@ -15,7 +15,7 @@ import io.airbyte.cdk.load.dataflow.transform.ValidationResult
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.test.util.ExpectedRecordMapper
import io.airbyte.cdk.load.test.util.OutputRecord
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.write.transform.SnowflakeValueCoercer
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Change

@@ -5,7 +5,7 @@
package io.airbyte.integrations.destination.snowflake.write

import io.airbyte.cdk.load.test.util.NameMapper
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName

class SnowflakeNameMapper : NameMapper {
override fun mapFieldName(path: List<String>): List<String> =
@@ -9,13 +9,16 @@ import io.airbyte.cdk.load.command.DestinationStream
|
||||
import io.airbyte.cdk.load.data.StringValue
|
||||
import io.airbyte.cdk.load.data.json.toAirbyteValue
|
||||
import io.airbyte.cdk.load.message.Meta
|
||||
import io.airbyte.cdk.load.table.DefaultTempTableNameGenerator
|
||||
import io.airbyte.cdk.load.test.util.DestinationDataDumper
|
||||
import io.airbyte.cdk.load.test.util.OutputRecord
|
||||
import io.airbyte.cdk.load.util.UUIDGenerator
|
||||
import io.airbyte.cdk.load.util.deserializeToNode
|
||||
import io.airbyte.integrations.destination.snowflake.SnowflakeBeanFactory
|
||||
import io.airbyte.integrations.destination.snowflake.db.SnowflakeFinalTableNameGenerator
|
||||
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
|
||||
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeTableSchemaMapper
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeSqlNameUtils
|
||||
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
|
||||
|
||||
class SnowflakeRawDataDumper(
|
||||
private val configProvider: (ConfigurationSpecification) -> SnowflakeConfiguration
|
||||
@@ -27,8 +30,18 @@ class SnowflakeRawDataDumper(
|
||||
val output = mutableListOf<OutputRecord>()
|
||||
|
||||
val config = configProvider(spec)
|
||||
val sqlUtils = SnowflakeSqlNameUtils(config)
|
||||
val snowflakeFinalTableNameGenerator = SnowflakeFinalTableNameGenerator(config)
|
||||
val snowflakeColumnManager = SnowflakeColumnManager(config)
|
||||
val sqlGenerator =
|
||||
SnowflakeDirectLoadSqlGenerator(
|
||||
UUIDGenerator(),
|
||||
config,
|
||||
snowflakeColumnManager,
|
||||
)
|
||||
val snowflakeFinalTableNameGenerator =
|
||||
SnowflakeTableSchemaMapper(
|
||||
config = config,
|
||||
tempTableNameGenerator = DefaultTempTableNameGenerator(),
|
||||
)
|
||||
val dataSource =
|
||||
SnowflakeBeanFactory()
|
||||
.snowflakeDataSource(snowflakeConfiguration = config, airbyteEdition = "COMMUNITY")
|
||||
@@ -37,11 +50,11 @@ class SnowflakeRawDataDumper(
|
||||
ds.connection.use { connection ->
|
||||
val statement = connection.createStatement()
|
||||
val tableName =
|
||||
snowflakeFinalTableNameGenerator.getTableName(stream.mappedDescriptor)
|
||||
snowflakeFinalTableNameGenerator.toFinalTableName(stream.mappedDescriptor)
|
||||
|
||||
val resultSet =
|
||||
statement.executeQuery(
|
||||
"SELECT * FROM ${sqlUtils.fullyQualifiedName(tableName)}"
|
||||
"SELECT * FROM ${sqlGenerator.fullyQualifiedName(tableName)}"
|
||||
)
|
||||
|
||||
while (resultSet.next()) {
|
||||
|
||||
@@ -6,7 +6,7 @@ package io.airbyte.integrations.destination.snowflake
|
||||
|
||||
import com.zaxxer.hikari.HikariConfig
|
||||
import com.zaxxer.hikari.HikariDataSource
|
||||
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
|
||||
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
|
||||
import io.airbyte.integrations.destination.snowflake.spec.CdcDeletionMode
|
||||
import io.airbyte.integrations.destination.snowflake.spec.KeyPairAuthConfiguration
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
|
||||
@@ -5,11 +5,9 @@
|
||||
package io.airbyte.integrations.destination.snowflake.check
|
||||
|
||||
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
|
||||
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
|
||||
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
|
||||
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
import io.airbyte.integrations.destination.snowflake.sql.DEFAULT_COLUMNS
|
||||
import io.airbyte.integrations.destination.snowflake.sql.RAW_DATA_COLUMN
|
||||
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
|
||||
import io.mockk.coEvery
|
||||
import io.mockk.coVerify
|
||||
import io.mockk.every
|
||||
@@ -23,47 +21,31 @@ internal class SnowflakeCheckerTest {
|
||||
@ParameterizedTest
|
||||
@ValueSource(booleans = [true, false])
|
||||
fun testSuccessfulCheck(isLegacyRawTablesOnly: Boolean) {
|
||||
val defaultColumnsMap =
|
||||
if (isLegacyRawTablesOnly) {
|
||||
linkedMapOf<String, String>().also { map ->
|
||||
(DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
|
||||
map[it.columnName] = it.columnType
|
||||
}
|
||||
}
|
||||
} else {
|
||||
linkedMapOf<String, String>().also { map ->
|
||||
(DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
|
||||
map[it.columnName.toSnowflakeCompatibleName()] = it.columnType
|
||||
}
|
||||
}
|
||||
}
|
||||
val defaultColumns = defaultColumnsMap.keys.toMutableList()
|
||||
val snowflakeAirbyteClient: SnowflakeAirbyteClient =
|
||||
mockk(relaxed = true) {
|
||||
coEvery { countTable(any()) } returns 1L
|
||||
coEvery { describeTable(any()) } returns defaultColumnsMap
|
||||
}
|
||||
mockk(relaxed = true) { coEvery { countTable(any()) } returns 1L }
|
||||
|
||||
val testSchema = "test-schema"
|
||||
val snowflakeConfiguration: SnowflakeConfiguration = mockk {
|
||||
every { schema } returns testSchema
|
||||
every { legacyRawTablesOnly } returns isLegacyRawTablesOnly
|
||||
}
|
||||
val snowflakeColumnUtils =
|
||||
mockk<SnowflakeColumnUtils>(relaxUnitFun = true) {
|
||||
every { getFormattedDefaultColumnNames(any()) } returns defaultColumns
|
||||
}
|
||||
|
||||
val columnManager: SnowflakeColumnManager = SnowflakeColumnManager(snowflakeConfiguration)
|
||||
|
||||
val checker =
|
||||
SnowflakeChecker(
|
||||
snowflakeAirbyteClient = snowflakeAirbyteClient,
|
||||
snowflakeConfiguration = snowflakeConfiguration,
|
||||
snowflakeColumnUtils = snowflakeColumnUtils,
|
||||
columnManager = columnManager,
|
||||
)
|
||||
checker.check()
|
||||
|
||||
coVerify(exactly = 1) {
|
||||
snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
|
||||
if (isLegacyRawTablesOnly) {
|
||||
snowflakeAirbyteClient.createNamespace(testSchema)
|
||||
} else {
|
||||
snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
|
||||
}
|
||||
}
|
||||
coVerify(exactly = 1) { snowflakeAirbyteClient.createTable(any(), any(), any(), any()) }
|
||||
coVerify(exactly = 1) { snowflakeAirbyteClient.dropTable(any()) }
|
||||
@@ -72,48 +54,32 @@ internal class SnowflakeCheckerTest {
|
||||
@ParameterizedTest
|
||||
@ValueSource(booleans = [true, false])
|
||||
fun testUnsuccessfulCheck(isLegacyRawTablesOnly: Boolean) {
|
||||
val defaultColumnsMap =
|
||||
if (isLegacyRawTablesOnly) {
|
||||
linkedMapOf<String, String>().also { map ->
|
||||
(DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
|
||||
map[it.columnName] = it.columnType
|
||||
}
|
||||
}
|
||||
} else {
|
||||
linkedMapOf<String, String>().also { map ->
|
||||
(DEFAULT_COLUMNS + RAW_DATA_COLUMN).forEach {
|
||||
map[it.columnName.toSnowflakeCompatibleName()] = it.columnType
|
||||
}
|
||||
}
|
||||
}
|
||||
val defaultColumns = defaultColumnsMap.keys.toMutableList()
|
||||
val snowflakeAirbyteClient: SnowflakeAirbyteClient =
|
||||
mockk(relaxed = true) {
|
||||
coEvery { countTable(any()) } returns 0L
|
||||
coEvery { describeTable(any()) } returns defaultColumnsMap
|
||||
}
|
||||
mockk(relaxed = true) { coEvery { countTable(any()) } returns 0L }
|
||||
|
||||
val testSchema = "test-schema"
|
||||
val snowflakeConfiguration: SnowflakeConfiguration = mockk {
|
||||
every { schema } returns testSchema
|
||||
every { legacyRawTablesOnly } returns isLegacyRawTablesOnly
|
||||
}
|
||||
val snowflakeColumnUtils =
|
||||
mockk<SnowflakeColumnUtils>(relaxUnitFun = true) {
|
||||
every { getFormattedDefaultColumnNames(any()) } returns defaultColumns
|
||||
}
|
||||
|
||||
val columnManager: SnowflakeColumnManager = SnowflakeColumnManager(snowflakeConfiguration)
|
||||
|
||||
val checker =
|
||||
SnowflakeChecker(
|
||||
snowflakeAirbyteClient = snowflakeAirbyteClient,
|
||||
snowflakeConfiguration = snowflakeConfiguration,
|
||||
snowflakeColumnUtils = snowflakeColumnUtils,
|
||||
columnManager = columnManager,
|
||||
)
|
||||
|
||||
assertThrows<IllegalArgumentException> { checker.check() }
|
||||
|
||||
coVerify(exactly = 1) {
|
||||
snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
|
||||
if (isLegacyRawTablesOnly) {
|
||||
snowflakeAirbyteClient.createNamespace(testSchema)
|
||||
} else {
|
||||
snowflakeAirbyteClient.createNamespace(testSchema.toSnowflakeCompatibleName())
|
||||
}
|
||||
}
|
||||
coVerify(exactly = 1) { snowflakeAirbyteClient.createTable(any(), any(), any(), any()) }
|
||||
coVerify(exactly = 1) { snowflakeAirbyteClient.dropTable(any()) }
|
||||
|
||||
@@ -6,24 +6,16 @@ package io.airbyte.integrations.destination.snowflake.client
|
||||
|
||||
import io.airbyte.cdk.ConfigErrorException
|
||||
import io.airbyte.cdk.load.command.DestinationStream
|
||||
import io.airbyte.cdk.load.command.NamespaceMapper
|
||||
import io.airbyte.cdk.load.command.Overwrite
|
||||
import io.airbyte.cdk.load.component.ColumnType
|
||||
import io.airbyte.cdk.load.config.NamespaceDefinitionType
|
||||
import io.airbyte.cdk.load.data.AirbyteType
|
||||
import io.airbyte.cdk.load.data.FieldType
|
||||
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
|
||||
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
|
||||
import io.airbyte.cdk.load.schema.model.TableName
|
||||
import io.airbyte.cdk.load.table.ColumnNameMapping
|
||||
import io.airbyte.cdk.load.table.TableName
|
||||
import io.airbyte.integrations.destination.snowflake.db.ColumnDefinition
|
||||
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
|
||||
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
|
||||
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
import io.airbyte.integrations.destination.snowflake.sql.COUNT_TOTAL_ALIAS
|
||||
import io.airbyte.integrations.destination.snowflake.sql.ColumnAndType
|
||||
import io.airbyte.integrations.destination.snowflake.sql.DEFAULT_COLUMNS
|
||||
import io.airbyte.integrations.destination.snowflake.sql.QUOTE
|
||||
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
|
||||
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeDirectLoadSqlGenerator
|
||||
import io.mockk.Runs
|
||||
import io.mockk.every
|
||||
@@ -49,31 +41,18 @@ internal class SnowflakeAirbyteClientTest {
|
||||
private lateinit var client: SnowflakeAirbyteClient
|
||||
private lateinit var dataSource: DataSource
|
||||
private lateinit var sqlGenerator: SnowflakeDirectLoadSqlGenerator
|
||||
private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
|
||||
private lateinit var snowflakeConfiguration: SnowflakeConfiguration
|
||||
private lateinit var columnManager: SnowflakeColumnManager
|
||||
|
||||
@BeforeEach
|
||||
fun setup() {
|
||||
dataSource = mockk()
|
||||
sqlGenerator = mockk(relaxed = true)
|
||||
snowflakeColumnUtils =
|
||||
mockk(relaxed = true) {
|
||||
every { formatColumnName(any()) } answers
|
||||
{
|
||||
firstArg<String>().toSnowflakeCompatibleName()
|
||||
}
|
||||
every { getFormattedDefaultColumnNames(any()) } returns
|
||||
DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() }
|
||||
}
|
||||
snowflakeConfiguration =
|
||||
mockk(relaxed = true) { every { database } returns "test_database" }
|
||||
columnManager = mockk(relaxed = true)
|
||||
client =
|
||||
SnowflakeAirbyteClient(
|
||||
dataSource,
|
||||
sqlGenerator,
|
||||
snowflakeColumnUtils,
|
||||
snowflakeConfiguration
|
||||
)
|
||||
SnowflakeAirbyteClient(dataSource, sqlGenerator, snowflakeConfiguration, columnManager)
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -231,7 +210,7 @@ internal class SnowflakeAirbyteClientTest {
|
||||
@Test
|
||||
fun testCreateTable() {
|
||||
val columnNameMapping = mockk<ColumnNameMapping>(relaxed = true)
|
||||
val stream = mockk<DestinationStream>()
|
||||
val stream = mockk<DestinationStream>(relaxed = true)
|
||||
val tableName = TableName(namespace = "namespace", name = "name")
|
||||
val resultSet = mockk<ResultSet>(relaxed = true)
|
||||
val statement =
|
||||
@@ -254,9 +233,7 @@ internal class SnowflakeAirbyteClientTest {
|
||||
columnNameMapping = columnNameMapping,
|
||||
replace = true,
|
||||
)
|
||||
verify(exactly = 1) {
|
||||
sqlGenerator.createTable(stream, tableName, columnNameMapping, true)
|
||||
}
|
||||
verify(exactly = 1) { sqlGenerator.createTable(tableName, any(), true) }
|
||||
verify(exactly = 1) { sqlGenerator.createSnowflakeStage(tableName) }
|
||||
verify(exactly = 2) { mockConnection.close() }
|
||||
}
|
||||
@@ -288,7 +265,7 @@ internal class SnowflakeAirbyteClientTest {
|
||||
targetTableName = destinationTableName,
|
||||
)
|
||||
verify(exactly = 1) {
|
||||
sqlGenerator.copyTable(columnNameMapping, sourceTableName, destinationTableName)
|
||||
sqlGenerator.copyTable(any<Set<String>>(), sourceTableName, destinationTableName)
|
||||
}
|
||||
verify(exactly = 1) { mockConnection.close() }
|
||||
}
|
||||
@@ -299,7 +276,7 @@ internal class SnowflakeAirbyteClientTest {
|
||||
val columnNameMapping = mockk<ColumnNameMapping>(relaxed = true)
|
||||
val sourceTableName = TableName(namespace = "namespace", name = "source")
|
||||
val destinationTableName = TableName(namespace = "namespace", name = "destination")
|
||||
val stream = mockk<DestinationStream>()
|
||||
val stream = mockk<DestinationStream>(relaxed = true)
|
||||
val resultSet = mockk<ResultSet>(relaxed = true)
|
||||
val statement =
|
||||
mockk<Statement> {
|
||||
@@ -322,12 +299,7 @@ internal class SnowflakeAirbyteClientTest {
|
||||
targetTableName = destinationTableName,
|
||||
)
|
||||
verify(exactly = 1) {
|
||||
sqlGenerator.upsertTable(
|
||||
stream,
|
||||
columnNameMapping,
|
||||
sourceTableName,
|
||||
destinationTableName
|
||||
)
|
||||
sqlGenerator.upsertTable(any(), sourceTableName, destinationTableName)
|
||||
}
|
||||
verify(exactly = 1) { mockConnection.close() }
|
||||
}
|
||||
@@ -379,7 +351,7 @@ internal class SnowflakeAirbyteClientTest {
|
||||
}
|
||||
|
||||
every { dataSource.connection } returns mockConnection
|
||||
every { snowflakeColumnUtils.getGenerationIdColumnName() } returns generationIdColumnName
|
||||
every { columnManager.getGenerationIdColumnName() } returns generationIdColumnName
|
||||
every { sqlGenerator.getGenerationId(tableName) } returns
|
||||
"SELECT $generationIdColumnName FROM ${tableName.toPrettyString(QUOTE)}"
|
||||
|
||||
@@ -501,8 +473,8 @@ internal class SnowflakeAirbyteClientTest {
|
||||
every { dataSource.connection } returns mockConnection
|
||||
|
||||
runBlocking {
|
||||
client.copyFromStage(tableName, "test.csv.gz")
|
||||
verify(exactly = 1) { sqlGenerator.copyFromStage(tableName, "test.csv.gz") }
|
||||
client.copyFromStage(tableName, "test.csv.gz", listOf())
|
||||
verify(exactly = 1) { sqlGenerator.copyFromStage(tableName, "test.csv.gz", listOf()) }
|
||||
verify(exactly = 1) { mockConnection.close() }
|
||||
}
|
||||
}
|
||||
@@ -556,7 +528,7 @@ internal class SnowflakeAirbyteClientTest {
|
||||
"COL1" andThen
|
||||
COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName() andThen
|
||||
"COL2"
|
||||
every { resultSet.getString("type") } returns "VARCHAR(255)" andThen "NUMBER(38,0)"
|
||||
every { resultSet.getString("type") } returns "VARCHAR(255)" andThen "NUMBER"
|
||||
every { resultSet.getString("null?") } returns "Y" andThen "N" andThen "N"
|
||||
|
||||
val statement =
|
||||
@@ -571,6 +543,10 @@ internal class SnowflakeAirbyteClientTest {
|
||||
|
||||
every { dataSource.connection } returns connection
|
||||
|
||||
// Mock the columnManager to return the correct set of meta columns
|
||||
every { columnManager.getMetaColumnNames() } returns
|
||||
setOf(COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName())
|
||||
|
||||
val result = client.getColumnsFromDb(tableName)
|
||||
|
||||
val expectedColumns =
|
||||
@@ -582,81 +558,6 @@ internal class SnowflakeAirbyteClientTest {
|
||||
assertEquals(expectedColumns, result)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getColumnsFromStream should return correct column definitions`() {
|
||||
val schema = mockk<AirbyteType>()
|
||||
val stream =
|
||||
DestinationStream(
|
||||
unmappedNamespace = "test_namespace",
|
||||
unmappedName = "test_stream",
|
||||
importType = Overwrite,
|
||||
schema = schema,
|
||||
generationId = 1,
|
||||
minimumGenerationId = 1,
|
||||
syncId = 1,
|
||||
namespaceMapper = NamespaceMapper(NamespaceDefinitionType.DESTINATION)
|
||||
)
|
||||
val columnNameMapping =
|
||||
ColumnNameMapping(
|
||||
mapOf(
|
||||
"col1" to "COL1_MAPPED",
|
||||
"col2" to "COL2_MAPPED",
|
||||
)
|
||||
)
|
||||
|
||||
val col1FieldType = mockk<FieldType>()
|
||||
every { col1FieldType.type } returns mockk()
|
||||
|
||||
val col2FieldType = mockk<FieldType>()
|
||||
every { col2FieldType.type } returns mockk()
|
||||
|
||||
every { schema.asColumns() } returns
|
||||
linkedMapOf("col1" to col1FieldType, "col2" to col2FieldType)
|
||||
every { snowflakeColumnUtils.toDialectType(col1FieldType.type) } returns "VARCHAR(255)"
|
||||
every { snowflakeColumnUtils.toDialectType(col2FieldType.type) } returns "NUMBER(38,0)"
|
||||
every { snowflakeColumnUtils.columnsAndTypes(any(), any()) } returns
|
||||
listOf(ColumnAndType("COL1_MAPPED", "VARCHAR"), ColumnAndType("COL2_MAPPED", "NUMBER"))
|
||||
every { snowflakeColumnUtils.formatColumnName(any(), false) } answers
|
||||
{
|
||||
firstArg<String>().toSnowflakeCompatibleName()
|
||||
}
|
||||
|
||||
val result = client.getColumnsFromStream(stream, columnNameMapping)
|
||||
|
||||
val expectedColumns =
|
||||
mapOf(
|
||||
"COL1_MAPPED" to ColumnType("VARCHAR", true),
|
||||
"COL2_MAPPED" to ColumnType("NUMBER", true),
|
||||
)
|
||||
|
||||
assertEquals(expectedColumns, result)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `generateSchemaChanges should correctly identify changes`() {
|
||||
val columnsInDb =
|
||||
setOf(
|
||||
ColumnDefinition("COL1", "VARCHAR"),
|
||||
ColumnDefinition("COL2", "NUMBER"),
|
||||
ColumnDefinition("COL3", "BOOLEAN")
|
||||
)
|
||||
val columnsInStream =
|
||||
setOf(
|
||||
ColumnDefinition("COL1", "VARCHAR"), // Unchanged
|
||||
ColumnDefinition("COL3", "TEXT"), // Modified
|
||||
ColumnDefinition("COL4", "DATE") // Added
|
||||
)
|
||||
|
||||
val (added, deleted, modified) = client.generateSchemaChanges(columnsInDb, columnsInStream)
|
||||
|
||||
assertEquals(1, added.size)
|
||||
assertEquals("COL4", added.first().name)
|
||||
assertEquals(1, deleted.size)
|
||||
assertEquals("COL2", deleted.first().name)
|
||||
assertEquals(1, modified.size)
|
||||
assertEquals("COL3", modified.first().name)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testCreateNamespaceWithNetworkFailure() {
|
||||
val namespace = "test_namespace"
|
||||
|
||||
@@ -4,14 +4,19 @@
|
||||
|
||||
package io.airbyte.integrations.destination.snowflake.dataflow
|
||||
|
||||
import io.airbyte.cdk.load.command.DestinationCatalog
|
||||
import io.airbyte.cdk.load.command.DestinationStream
|
||||
import io.airbyte.cdk.load.command.DestinationStream.Descriptor
|
||||
import io.airbyte.cdk.load.dataflow.aggregate.StoreKey
|
||||
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableExecutionConfig
|
||||
import io.airbyte.cdk.load.table.TableName
|
||||
import io.airbyte.cdk.load.schema.model.TableName
|
||||
import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
|
||||
import io.airbyte.cdk.load.write.StreamStateStore
|
||||
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
|
||||
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
|
||||
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRawRecordFormatter
|
||||
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeRecordFormatter
|
||||
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeSchemaRecordFormatter
|
||||
import io.mockk.every
|
||||
import io.mockk.mockk
|
||||
import org.junit.jupiter.api.Assertions.assertEquals
|
||||
@@ -22,27 +27,33 @@ internal class SnowflakeAggregateFactoryTest {
|
||||
@Test
|
||||
fun testCreatingAggregateWithRawBuffer() {
|
||||
val descriptor = Descriptor(namespace = "namespace", name = "name")
|
||||
val directLoadTableExecutionConfig =
|
||||
DirectLoadTableExecutionConfig(
|
||||
tableName =
|
||||
TableName(
|
||||
namespace = descriptor.namespace!!,
|
||||
name = descriptor.name,
|
||||
)
|
||||
val tableName =
|
||||
TableName(
|
||||
namespace = descriptor.namespace!!,
|
||||
name = descriptor.name,
|
||||
)
|
||||
val directLoadTableExecutionConfig = DirectLoadTableExecutionConfig(tableName = tableName)
|
||||
val key = StoreKey(namespace = descriptor.namespace!!, name = descriptor.name)
|
||||
val streamStore = StreamStateStore<DirectLoadTableExecutionConfig>()
|
||||
streamStore.put(descriptor, directLoadTableExecutionConfig)
|
||||
streamStore.put(key, directLoadTableExecutionConfig)
|
||||
|
||||
val stream = mockk<DestinationStream>(relaxed = true)
|
||||
val catalog = mockk<DestinationCatalog> { every { getStream(key) } returns stream }
|
||||
|
||||
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
|
||||
val snowflakeConfiguration =
|
||||
mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns true }
|
||||
val snowflakeColumnUtils = mockk<SnowflakeColumnUtils>(relaxed = true)
|
||||
val columnManager = SnowflakeColumnManager(snowflakeConfiguration)
|
||||
val snowflakeRecordFormatter: SnowflakeRecordFormatter = SnowflakeRawRecordFormatter()
|
||||
|
||||
val factory =
|
||||
SnowflakeAggregateFactory(
|
||||
snowflakeClient = snowflakeClient,
|
||||
streamStateStore = streamStore,
|
||||
snowflakeConfiguration = snowflakeConfiguration,
|
||||
snowflakeColumnUtils = snowflakeColumnUtils,
|
||||
catalog = catalog,
|
||||
columnManager = columnManager,
|
||||
snowflakeRecordFormatter = snowflakeRecordFormatter,
|
||||
)
|
||||
val aggregate = factory.create(key)
|
||||
assertNotNull(aggregate)
|
||||
@@ -52,26 +63,33 @@ internal class SnowflakeAggregateFactoryTest {
|
||||
@Test
|
||||
fun testCreatingAggregateWithStagingBuffer() {
|
||||
val descriptor = Descriptor(namespace = "namespace", name = "name")
|
||||
val directLoadTableExecutionConfig =
|
||||
DirectLoadTableExecutionConfig(
|
||||
tableName =
|
||||
TableName(
|
||||
namespace = descriptor.namespace!!,
|
||||
name = descriptor.name,
|
||||
)
|
||||
val tableName =
|
||||
TableName(
|
||||
namespace = descriptor.namespace!!,
|
||||
name = descriptor.name,
|
||||
)
|
||||
val directLoadTableExecutionConfig = DirectLoadTableExecutionConfig(tableName = tableName)
|
||||
val key = StoreKey(namespace = descriptor.namespace!!, name = descriptor.name)
|
||||
val streamStore = StreamStateStore<DirectLoadTableExecutionConfig>()
|
||||
streamStore.put(descriptor, directLoadTableExecutionConfig)
|
||||
streamStore.put(key, directLoadTableExecutionConfig)
|
||||
|
||||
val stream = mockk<DestinationStream>(relaxed = true)
|
||||
val catalog = mockk<DestinationCatalog> { every { getStream(key) } returns stream }
|
||||
|
||||
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
|
||||
val snowflakeConfiguration = mockk<SnowflakeConfiguration>(relaxed = true)
|
||||
val snowflakeColumnUtils = mockk<SnowflakeColumnUtils>(relaxed = true)
|
||||
val snowflakeConfiguration =
|
||||
mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns false }
|
||||
val columnManager = SnowflakeColumnManager(snowflakeConfiguration)
|
||||
val snowflakeRecordFormatter: SnowflakeRecordFormatter = SnowflakeSchemaRecordFormatter()
|
||||
|
||||
val factory =
|
||||
SnowflakeAggregateFactory(
|
||||
snowflakeClient = snowflakeClient,
|
||||
streamStateStore = streamStore,
|
||||
snowflakeConfiguration = snowflakeConfiguration,
|
||||
snowflakeColumnUtils = snowflakeColumnUtils,
|
||||
catalog = catalog,
|
||||
columnManager = columnManager,
|
||||
snowflakeRecordFormatter = snowflakeRecordFormatter,
|
||||
)
|
||||
val aggregate = factory.create(key)
|
||||
assertNotNull(aggregate)
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination.snowflake.db
|
||||
|
||||
import io.mockk.every
|
||||
import io.mockk.mockk
|
||||
import org.junit.jupiter.api.Assertions.assertEquals
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
internal class SnowflakeColumnNameGeneratorTest {
|
||||
|
||||
@Test
|
||||
fun testGetColumnName() {
|
||||
val column = "test-column"
|
||||
val generator =
|
||||
SnowflakeColumnNameGenerator(mockk { every { legacyRawTablesOnly } returns false })
|
||||
val columnName = generator.getColumnName(column)
|
||||
assertEquals(column.toSnowflakeCompatibleName(), columnName.displayName)
|
||||
assertEquals(column.toSnowflakeCompatibleName(), columnName.canonicalName)
|
||||
}
|
||||
}
|
||||
@@ -1,72 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination.snowflake.db
|
||||
|
||||
import io.airbyte.cdk.load.command.DestinationStream
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
import io.mockk.every
|
||||
import io.mockk.mockk
|
||||
import org.junit.jupiter.api.Assertions.assertEquals
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
internal class SnowflakeFinalTableNameGeneratorTest {
|
||||
|
||||
@Test
|
||||
fun testGetTableNameWithInternalNamespace() {
|
||||
val configuration =
|
||||
mockk<SnowflakeConfiguration> {
|
||||
every { internalTableSchema } returns "test-internal-namespace"
|
||||
every { legacyRawTablesOnly } returns true
|
||||
}
|
||||
val generator = SnowflakeFinalTableNameGenerator(config = configuration)
|
||||
val streamName = "test-stream-name"
|
||||
val streamNamespace = "test-stream-namespace"
|
||||
val streamDescriptor =
|
||||
mockk<DestinationStream.Descriptor> {
|
||||
every { namespace } returns streamNamespace
|
||||
every { name } returns streamName
|
||||
}
|
||||
val tableName = generator.getTableName(streamDescriptor)
|
||||
assertEquals("test-stream-namespace_raw__stream_test-stream-name", tableName.name)
|
||||
assertEquals("test-internal-namespace", tableName.namespace)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testGetTableNameWithNamespace() {
|
||||
val configuration =
|
||||
mockk<SnowflakeConfiguration> { every { legacyRawTablesOnly } returns false }
|
||||
val generator = SnowflakeFinalTableNameGenerator(config = configuration)
|
||||
val streamName = "test-stream-name"
|
||||
val streamNamespace = "test-stream-namespace"
|
||||
val streamDescriptor =
|
||||
mockk<DestinationStream.Descriptor> {
|
||||
every { namespace } returns streamNamespace
|
||||
every { name } returns streamName
|
||||
}
|
||||
val tableName = generator.getTableName(streamDescriptor)
|
||||
assertEquals("TEST-STREAM-NAME", tableName.name)
|
||||
assertEquals("TEST-STREAM-NAMESPACE", tableName.namespace)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testGetTableNameWithDefaultNamespace() {
|
||||
val defaultNamespace = "test-default-namespace"
|
||||
val configuration =
|
||||
mockk<SnowflakeConfiguration> {
|
||||
every { schema } returns defaultNamespace
|
||||
every { legacyRawTablesOnly } returns false
|
||||
}
|
||||
val generator = SnowflakeFinalTableNameGenerator(config = configuration)
|
||||
val streamName = "test-stream-name"
|
||||
val streamDescriptor =
|
||||
mockk<DestinationStream.Descriptor> {
|
||||
every { namespace } returns null
|
||||
every { name } returns streamName
|
||||
}
|
||||
val tableName = generator.getTableName(streamDescriptor)
|
||||
assertEquals("TEST-STREAM-NAME", tableName.name)
|
||||
assertEquals("TEST-DEFAULT-NAMESPACE", tableName.namespace)
|
||||
}
|
||||
}
|
||||
@@ -1,27 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.integrations.destination.snowflake.db

import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.CsvSource

internal class SnowflakeNameGeneratorsTest {

@ParameterizedTest
@CsvSource(
value =
[
"test-name,TEST-NAME",
"1-test-name,1-TEST-NAME",
"test-name!!!,TEST-NAME!!!",
"test\${name,TEST__NAME",
"test\"name,TEST\"\"NAME",
]
)
fun testToSnowflakeCompatibleName(name: String, expected: String) {
assertEquals(expected, name.toSnowflakeCompatibleName())
}
}
@@ -1,358 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination.snowflake.sql
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode
|
||||
import io.airbyte.cdk.load.data.ArrayType
|
||||
import io.airbyte.cdk.load.data.ArrayTypeWithoutSchema
|
||||
import io.airbyte.cdk.load.data.BooleanType
|
||||
import io.airbyte.cdk.load.data.DateType
|
||||
import io.airbyte.cdk.load.data.FieldType
|
||||
import io.airbyte.cdk.load.data.IntegerType
|
||||
import io.airbyte.cdk.load.data.NumberType
|
||||
import io.airbyte.cdk.load.data.ObjectType
|
||||
import io.airbyte.cdk.load.data.ObjectTypeWithEmptySchema
|
||||
import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
|
||||
import io.airbyte.cdk.load.data.StringType
|
||||
import io.airbyte.cdk.load.data.TimeTypeWithTimezone
|
||||
import io.airbyte.cdk.load.data.TimeTypeWithoutTimezone
|
||||
import io.airbyte.cdk.load.data.TimestampTypeWithTimezone
|
||||
import io.airbyte.cdk.load.data.TimestampTypeWithoutTimezone
|
||||
import io.airbyte.cdk.load.data.UnionType
|
||||
import io.airbyte.cdk.load.data.UnknownType
|
||||
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA
|
||||
import io.airbyte.cdk.load.orchestration.db.ColumnNameGenerator
|
||||
import io.airbyte.cdk.load.table.CDC_DELETED_AT_COLUMN
|
||||
import io.airbyte.cdk.load.table.ColumnNameMapping
|
||||
import io.airbyte.integrations.destination.snowflake.db.SnowflakeColumnNameGenerator
|
||||
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
import io.mockk.every
|
||||
import io.mockk.mockk
|
||||
import kotlin.collections.LinkedHashMap
|
||||
import org.junit.jupiter.api.Assertions.assertEquals
|
||||
import org.junit.jupiter.api.BeforeEach
|
||||
import org.junit.jupiter.api.Test
|
||||
import org.junit.jupiter.params.ParameterizedTest
|
||||
import org.junit.jupiter.params.provider.CsvSource
|
||||
|
||||
internal class SnowflakeColumnUtilsTest {
|
||||
|
||||
private lateinit var snowflakeConfiguration: SnowflakeConfiguration
|
||||
private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
|
||||
private lateinit var snowflakeColumnNameGenerator: SnowflakeColumnNameGenerator
|
||||
|
||||
@BeforeEach
|
||||
fun setup() {
|
||||
snowflakeConfiguration = mockk(relaxed = true)
|
||||
snowflakeColumnNameGenerator =
|
||||
mockk(relaxed = true) {
|
||||
every { getColumnName(any()) } answers
|
||||
{
|
||||
val displayName =
|
||||
if (snowflakeConfiguration.legacyRawTablesOnly) firstArg<String>()
|
||||
else firstArg<String>().toSnowflakeCompatibleName()
|
||||
val canonicalName =
|
||||
if (snowflakeConfiguration.legacyRawTablesOnly) firstArg<String>()
|
||||
else firstArg<String>().toSnowflakeCompatibleName()
|
||||
ColumnNameGenerator.ColumnName(
|
||||
displayName = displayName,
|
||||
canonicalName = canonicalName,
|
||||
)
|
||||
}
|
||||
}
|
||||
snowflakeColumnUtils =
|
||||
SnowflakeColumnUtils(snowflakeConfiguration, snowflakeColumnNameGenerator)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDefaultColumns() {
|
||||
val expectedDefaultColumns = DEFAULT_COLUMNS
|
||||
assertEquals(expectedDefaultColumns, snowflakeColumnUtils.defaultColumns())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDefaultRawColumns() {
|
||||
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
|
||||
|
||||
val expectedDefaultColumns = DEFAULT_COLUMNS + RAW_COLUMNS
|
||||
|
||||
assertEquals(expectedDefaultColumns, snowflakeColumnUtils.defaultColumns())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testGetFormattedDefaultColumnNames() {
|
||||
val expectedDefaultColumnNames =
|
||||
DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() }
|
||||
val defaultColumnNames = snowflakeColumnUtils.getFormattedDefaultColumnNames()
|
||||
assertEquals(expectedDefaultColumnNames, defaultColumnNames)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testGetFormattedDefaultColumnNamesQuoted() {
|
||||
val expectedDefaultColumnNames =
|
||||
DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName().quote() }
|
||||
val defaultColumnNames = snowflakeColumnUtils.getFormattedDefaultColumnNames(true)
|
||||
assertEquals(expectedDefaultColumnNames, defaultColumnNames)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testGetColumnName() {
|
||||
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
|
||||
val columnNames = snowflakeColumnUtils.getColumnNames(columnNameMapping)
|
||||
val expectedColumnNames =
|
||||
(DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() } + listOf("actual"))
|
||||
.joinToString(",") { it.quote() }
|
||||
assertEquals(expectedColumnNames, columnNames)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testGetRawColumnName() {
|
||||
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
|
||||
|
||||
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
|
||||
val columnNames = snowflakeColumnUtils.getColumnNames(columnNameMapping)
|
||||
val expectedColumnNames =
|
||||
(DEFAULT_COLUMNS.map { it.columnName } + RAW_COLUMNS.map { it.columnName })
|
||||
.joinToString(",") { it.quote() }
|
||||
assertEquals(expectedColumnNames, columnNames)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testGetRawFormattedColumnNames() {
|
||||
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
|
||||
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
|
||||
val schemaColumns =
|
||||
mapOf(
|
||||
"column_one" to FieldType(StringType, true),
|
||||
"column_two" to FieldType(IntegerType, true),
|
||||
"original" to FieldType(StringType, true),
|
||||
CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true)
|
||||
)
|
||||
val expectedColumnNames =
|
||||
DEFAULT_COLUMNS.map { it.columnName.quote() } +
|
||||
RAW_COLUMNS.map { it.columnName.quote() }
|
||||
|
||||
val columnNames =
|
||||
snowflakeColumnUtils.getFormattedColumnNames(
|
||||
columns = schemaColumns,
|
||||
columnNameMapping = columnNameMapping
|
||||
)
|
||||
assertEquals(expectedColumnNames.size, columnNames.size)
|
||||
assertEquals(expectedColumnNames.sorted(), columnNames.sorted())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testGetFormattedColumnNames() {
|
||||
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
|
||||
val schemaColumns =
|
||||
mapOf(
|
||||
"column_one" to FieldType(StringType, true),
|
||||
"column_two" to FieldType(IntegerType, true),
|
||||
"original" to FieldType(StringType, true),
|
||||
CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true)
|
||||
)
|
||||
val expectedColumnNames =
|
||||
listOf(
|
||||
"actual",
|
||||
"column_one",
|
||||
"column_two",
|
||||
CDC_DELETED_AT_COLUMN,
|
||||
)
|
||||
.map { it.quote() } +
|
||||
DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName().quote() }
|
||||
val columnNames =
|
||||
snowflakeColumnUtils.getFormattedColumnNames(
|
||||
columns = schemaColumns,
|
||||
columnNameMapping = columnNameMapping
|
||||
)
|
||||
assertEquals(expectedColumnNames.size, columnNames.size)
|
||||
assertEquals(expectedColumnNames.sorted(), columnNames.sorted())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testGetFormattedColumnNamesNoQuotes() {
|
||||
val columnNameMapping = ColumnNameMapping(mapOf("original" to "actual"))
|
||||
val schemaColumns =
|
||||
mapOf(
|
||||
"column_one" to FieldType(StringType, true),
|
||||
"column_two" to FieldType(IntegerType, true),
|
||||
"original" to FieldType(StringType, true),
|
||||
CDC_DELETED_AT_COLUMN to FieldType(TimestampTypeWithTimezone, true)
|
||||
)
|
||||
val expectedColumnNames =
|
||||
listOf(
|
||||
"actual",
|
||||
"column_one",
|
||||
"column_two",
|
||||
CDC_DELETED_AT_COLUMN,
|
||||
) + DEFAULT_COLUMNS.map { it.columnName.toSnowflakeCompatibleName() }
|
||||
val columnNames =
|
||||
snowflakeColumnUtils.getFormattedColumnNames(
|
||||
columns = schemaColumns,
|
||||
columnNameMapping = columnNameMapping,
|
||||
quote = false
|
||||
)
|
||||
assertEquals(expectedColumnNames.size, columnNames.size)
|
||||
assertEquals(expectedColumnNames.sorted(), columnNames.sorted())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testGeneratingRawTableColumnsAndTypesNoColumnMapping() {
|
||||
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
|
||||
|
||||
val columns =
|
||||
snowflakeColumnUtils.columnsAndTypes(
|
||||
columns = emptyMap(),
|
||||
columnNameMapping = ColumnNameMapping(emptyMap())
|
||||
)
|
||||
assertEquals(DEFAULT_COLUMNS.size + RAW_COLUMNS.size, columns.size)
|
||||
assertEquals(
|
||||
"${SnowflakeDataType.VARIANT.typeName} $NOT_NULL",
|
||||
columns.find { it.columnName == RAW_DATA_COLUMN.columnName }?.columnType
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testGeneratingColumnsAndTypesNoColumnMapping() {
|
||||
val columnName = "test-column"
|
||||
val fieldType = FieldType(StringType, false)
|
||||
val declaredColumns = mapOf(columnName to fieldType)
|
||||
|
||||
val columns =
|
||||
snowflakeColumnUtils.columnsAndTypes(
|
||||
columns = declaredColumns,
|
||||
columnNameMapping = ColumnNameMapping(emptyMap())
|
||||
)
|
||||
assertEquals(DEFAULT_COLUMNS.size + 1, columns.size)
|
||||
assertEquals(
|
||||
"${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL",
|
||||
columns.find { it.columnName == columnName }?.columnType
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testGeneratingColumnsAndTypesWithColumnMapping() {
|
||||
val columnName = "test-column"
|
||||
val mappedColumnName = "mapped-column-name"
|
||||
val fieldType = FieldType(StringType, false)
|
||||
val declaredColumns = mapOf(columnName to fieldType)
|
||||
val columnNameMapping = ColumnNameMapping(mapOf(columnName to mappedColumnName))
|
||||
|
||||
val columns =
|
||||
snowflakeColumnUtils.columnsAndTypes(
|
||||
columns = declaredColumns,
|
||||
columnNameMapping = columnNameMapping
|
||||
)
|
||||
assertEquals(DEFAULT_COLUMNS.size + 1, columns.size)
|
||||
assertEquals(
|
||||
"${SnowflakeDataType.VARCHAR.typeName} $NOT_NULL",
|
||||
columns.find { it.columnName == mappedColumnName }?.columnType
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testToDialectType() {
|
||||
assertEquals(
|
||||
SnowflakeDataType.BOOLEAN.typeName,
|
||||
snowflakeColumnUtils.toDialectType(BooleanType)
|
||||
)
|
||||
assertEquals(SnowflakeDataType.DATE.typeName, snowflakeColumnUtils.toDialectType(DateType))
|
||||
assertEquals(
|
||||
SnowflakeDataType.NUMBER.typeName,
|
||||
snowflakeColumnUtils.toDialectType(IntegerType)
|
||||
)
|
||||
assertEquals(
|
||||
SnowflakeDataType.FLOAT.typeName,
|
||||
snowflakeColumnUtils.toDialectType(NumberType)
|
||||
)
|
||||
assertEquals(
|
||||
SnowflakeDataType.VARCHAR.typeName,
|
||||
snowflakeColumnUtils.toDialectType(StringType)
|
||||
)
|
||||
assertEquals(
|
||||
SnowflakeDataType.VARCHAR.typeName,
|
||||
snowflakeColumnUtils.toDialectType(TimeTypeWithTimezone)
|
||||
)
|
||||
assertEquals(
|
||||
SnowflakeDataType.TIME.typeName,
|
||||
snowflakeColumnUtils.toDialectType(TimeTypeWithoutTimezone)
|
||||
)
|
||||
assertEquals(
|
||||
SnowflakeDataType.TIMESTAMP_TZ.typeName,
|
||||
snowflakeColumnUtils.toDialectType(TimestampTypeWithTimezone)
|
||||
)
|
||||
assertEquals(
|
||||
SnowflakeDataType.TIMESTAMP_NTZ.typeName,
|
||||
snowflakeColumnUtils.toDialectType(TimestampTypeWithoutTimezone)
|
||||
)
|
||||
assertEquals(
|
||||
SnowflakeDataType.ARRAY.typeName,
|
||||
snowflakeColumnUtils.toDialectType(ArrayType(items = FieldType(StringType, false)))
|
||||
)
|
||||
assertEquals(
|
||||
SnowflakeDataType.ARRAY.typeName,
|
||||
snowflakeColumnUtils.toDialectType(ArrayTypeWithoutSchema)
|
||||
)
|
||||
assertEquals(
|
||||
SnowflakeDataType.OBJECT.typeName,
|
||||
snowflakeColumnUtils.toDialectType(
|
||||
ObjectType(
|
||||
properties = LinkedHashMap(),
|
||||
additionalProperties = false,
|
||||
)
|
||||
)
|
||||
)
|
||||
assertEquals(
|
||||
SnowflakeDataType.OBJECT.typeName,
|
||||
snowflakeColumnUtils.toDialectType(ObjectTypeWithEmptySchema)
|
||||
)
|
||||
assertEquals(
|
||||
SnowflakeDataType.OBJECT.typeName,
|
||||
snowflakeColumnUtils.toDialectType(ObjectTypeWithoutSchema)
|
||||
)
|
||||
assertEquals(
|
||||
SnowflakeDataType.VARIANT.typeName,
|
||||
snowflakeColumnUtils.toDialectType(
|
||||
UnionType(
|
||||
options = setOf(StringType),
|
||||
isLegacyUnion = true,
|
||||
)
|
||||
)
|
||||
)
|
||||
assertEquals(
|
||||
SnowflakeDataType.VARIANT.typeName,
|
||||
snowflakeColumnUtils.toDialectType(
|
||||
UnionType(
|
||||
options = emptySet(),
|
||||
isLegacyUnion = false,
|
||||
)
|
||||
)
|
||||
)
|
||||
assertEquals(
|
||||
SnowflakeDataType.VARIANT.typeName,
|
||||
snowflakeColumnUtils.toDialectType(UnknownType(schema = mockk<JsonNode>()))
|
||||
)
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@CsvSource(
|
||||
value =
|
||||
[
|
||||
"$COLUMN_NAME_DATA, true, \"$COLUMN_NAME_DATA\"",
|
||||
"some-other_Column, true, \"SOME-OTHER_COLUMN\"",
|
||||
"$COLUMN_NAME_DATA, false, $COLUMN_NAME_DATA",
|
||||
"some-other_Column, false, SOME-OTHER_COLUMN",
|
||||
"$COLUMN_NAME_DATA, true, \"$COLUMN_NAME_DATA\"",
|
||||
"some-other_Column, true, \"SOME-OTHER_COLUMN\"",
|
||||
]
|
||||
)
|
||||
fun testFormatColumnName(columnName: String, quote: Boolean, expectedFormattedName: String) {
|
||||
assertEquals(
|
||||
expectedFormattedName,
|
||||
snowflakeColumnUtils.formatColumnName(columnName, quote)
|
||||
)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
@@ -1,97 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination.snowflake.sql
|
||||
|
||||
import io.airbyte.cdk.load.table.TableName
|
||||
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
|
||||
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
|
||||
import io.mockk.every
|
||||
import io.mockk.mockk
|
||||
import org.junit.jupiter.api.Assertions.assertEquals
|
||||
import org.junit.jupiter.api.BeforeEach
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
internal class SnowflakeSqlNameUtilsTest {
|
||||
|
||||
private lateinit var snowflakeConfiguration: SnowflakeConfiguration
|
||||
private lateinit var snowflakeSqlNameUtils: SnowflakeSqlNameUtils
|
||||
|
||||
@BeforeEach
|
||||
fun setUp() {
|
||||
snowflakeConfiguration = mockk(relaxed = true)
|
||||
snowflakeSqlNameUtils =
|
||||
SnowflakeSqlNameUtils(snowflakeConfiguration = snowflakeConfiguration)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testFullyQualifiedName() {
|
||||
val databaseName = "test-database"
|
||||
val namespace = "test-namespace"
|
||||
val name = "test=name"
|
||||
val tableName = TableName(namespace = namespace, name = name)
|
||||
every { snowflakeConfiguration.database } returns databaseName
|
||||
|
||||
val expectedName =
|
||||
snowflakeSqlNameUtils.combineParts(
|
||||
listOf(
|
||||
databaseName.toSnowflakeCompatibleName(),
|
||||
tableName.namespace,
|
||||
tableName.name
|
||||
)
|
||||
)
|
||||
val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedName(tableName)
|
||||
assertEquals(expectedName, fullyQualifiedName)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testFullyQualifiedNamespace() {
|
||||
val databaseName = "test-database"
|
||||
val namespace = "test-namespace"
|
||||
every { snowflakeConfiguration.database } returns databaseName
|
||||
|
||||
val fullyQualifiedNamespace = snowflakeSqlNameUtils.fullyQualifiedNamespace(namespace)
|
||||
assertEquals("\"TEST-DATABASE\".\"test-namespace\"", fullyQualifiedNamespace)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testFullyQualifiedStageName() {
|
||||
val databaseName = "test-database"
|
||||
val namespace = "test-namespace"
|
||||
val name = "test=name"
|
||||
val tableName = TableName(namespace = namespace, name = name)
|
||||
every { snowflakeConfiguration.database } returns databaseName
|
||||
|
||||
val expectedName =
|
||||
snowflakeSqlNameUtils.combineParts(
|
||||
listOf(
|
||||
databaseName.toSnowflakeCompatibleName(),
|
||||
namespace,
|
||||
"$STAGE_NAME_PREFIX$name"
|
||||
)
|
||||
)
|
||||
val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName)
|
||||
assertEquals(expectedName, fullyQualifiedName)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testFullyQualifiedStageNameWithEscape() {
|
||||
val databaseName = "test-database"
|
||||
val namespace = "test-namespace"
|
||||
val name = "test=\"\"\'name"
|
||||
val tableName = TableName(namespace = namespace, name = name)
|
||||
every { snowflakeConfiguration.database } returns databaseName
|
||||
|
||||
val expectedName =
|
||||
snowflakeSqlNameUtils.combineParts(
|
||||
listOf(
|
||||
databaseName.toSnowflakeCompatibleName(),
|
||||
namespace,
|
||||
"$STAGE_NAME_PREFIX${sqlEscape(name)}"
|
||||
)
|
||||
)
|
||||
val fullyQualifiedName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true)
|
||||
assertEquals(expectedName, fullyQualifiedName)
|
||||
}
|
||||
}
|
||||
@@ -6,20 +6,21 @@ package io.airbyte.integrations.destination.snowflake.write
|
||||
|
||||
import io.airbyte.cdk.SystemErrorException
|
||||
import io.airbyte.cdk.load.command.Append
|
||||
import io.airbyte.cdk.load.command.DestinationCatalog
|
||||
import io.airbyte.cdk.load.command.DestinationStream
|
||||
import io.airbyte.cdk.load.orchestration.db.DatabaseInitialStatusGatherer
|
||||
import io.airbyte.cdk.load.orchestration.db.TableNames
|
||||
import io.airbyte.cdk.load.orchestration.db.TempTableNameGenerator
|
||||
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadInitialStatus
|
||||
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendStreamLoader
|
||||
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableAppendTruncateStreamLoader
|
||||
import io.airbyte.cdk.load.orchestration.db.direct_load_table.DirectLoadTableStatus
|
||||
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableCatalog
|
||||
import io.airbyte.cdk.load.orchestration.db.legacy_typing_deduping.TableNameInfo
|
||||
import io.airbyte.cdk.load.table.ColumnNameMapping
|
||||
import io.airbyte.cdk.load.table.TableName
|
||||
import io.airbyte.cdk.load.schema.model.ColumnSchema
|
||||
import io.airbyte.cdk.load.schema.model.StreamTableSchema
|
||||
import io.airbyte.cdk.load.schema.model.TableName
|
||||
import io.airbyte.cdk.load.schema.model.TableNames
|
||||
import io.airbyte.cdk.load.table.DatabaseInitialStatusGatherer
|
||||
import io.airbyte.cdk.load.table.TempTableNameGenerator
import io.airbyte.cdk.load.table.directload.DirectLoadInitialStatus
import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendStreamLoader
import io.airbyte.cdk.load.table.directload.DirectLoadTableAppendTruncateStreamLoader
import io.airbyte.cdk.load.table.directload.DirectLoadTableExecutionConfig
import io.airbyte.cdk.load.table.directload.DirectLoadTableStatus
import io.airbyte.cdk.load.write.StreamStateStore
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
@@ -34,55 +35,93 @@ internal class SnowflakeWriterTest {
@Test
fun testSetup() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val stream = mockk<DestinationStream>()
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream =
mockk<DestinationStream> {
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
every { importType } returns Append
}
val catalog = DestinationCatalog(listOf(stream))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns emptyMap()
coEvery { gatherInitialStatus() } returns
mapOf(
stream to
DirectLoadInitialStatus(
realTable = DirectLoadTableStatus(false),
tempTable = null
)
)
}
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer =
SnowflakeWriter(
names = catalog,
catalog = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(relaxed = true),
snowflakeConfiguration =
mockk(relaxed = true) {
every { internalTableSchema } returns "internal_schema"
},
)

runBlocking { writer.setup() }

coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) }
coVerify(exactly = 1) { stateGatherer.gatherInitialStatus(catalog) }
coVerify(exactly = 1) { stateGatherer.gatherInitialStatus() }
}

@Test
fun testCreateStreamLoaderFirstGeneration() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream =
mockk<DestinationStream> {
every { minimumGenerationId } returns 0L
every { generationId } returns 0L
every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
}
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val catalog = DestinationCatalog(listOf(stream))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns
coEvery { gatherInitialStatus() } returns
mapOf(
stream to
DirectLoadInitialStatus(
@@ -93,14 +132,18 @@ internal class SnowflakeWriterTest {
}
val tempTableNameGenerator =
mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } }
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer =
SnowflakeWriter(
names = catalog,
catalog = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient,
tempTableNameGenerator = tempTableNameGenerator,
snowflakeConfiguration = mockk(relaxed = true),
snowflakeConfiguration =
mockk(relaxed = true) {
every { internalTableSchema } returns "internal_schema"
},
)

runBlocking {
@@ -113,23 +156,35 @@ internal class SnowflakeWriterTest {
@Test
fun testCreateStreamLoaderNotFirstGeneration() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream =
mockk<DestinationStream> {
every { minimumGenerationId } returns 1L
every { generationId } returns 1L
every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
}
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val catalog = DestinationCatalog(listOf(stream))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns
coEvery { gatherInitialStatus() } returns
mapOf(
stream to
DirectLoadInitialStatus(
@@ -140,14 +195,18 @@ internal class SnowflakeWriterTest {
}
val tempTableNameGenerator =
mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } }
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer =
SnowflakeWriter(
names = catalog,
catalog = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient,
tempTableNameGenerator = tempTableNameGenerator,
snowflakeConfiguration = mockk(relaxed = true),
snowflakeConfiguration =
mockk(relaxed = true) {
every { internalTableSchema } returns "internal_schema"
},
)

runBlocking {
@@ -160,22 +219,35 @@ internal class SnowflakeWriterTest {
@Test
fun testCreateStreamLoaderHybrid() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val tempTableName = TableName(namespace = "test-namespace", name = "test-name-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream =
mockk<DestinationStream> {
every { minimumGenerationId } returns 1L
every { generationId } returns 2L
every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
}
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val catalog = DestinationCatalog(listOf(stream))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns
coEvery { gatherInitialStatus() } returns
mapOf(
stream to
DirectLoadInitialStatus(
@@ -186,14 +258,18 @@ internal class SnowflakeWriterTest {
}
val tempTableNameGenerator =
mockk<TempTableNameGenerator> { every { generate(any()) } answers { firstArg() } }
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer =
SnowflakeWriter(
names = catalog,
catalog = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient,
tempTableNameGenerator = tempTableNameGenerator,
snowflakeConfiguration = mockk(relaxed = true),
snowflakeConfiguration =
mockk(relaxed = true) {
every { internalTableSchema } returns "internal_schema"
},
)

runBlocking {
@@ -203,169 +279,126 @@ internal class SnowflakeWriterTest {
}

@Test
fun testSetupWithNamespaceCreationFailure() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val stream = mockk<DestinationStream>()
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val snowflakeClient = mockk<SnowflakeAirbyteClient>()
val stateGatherer = mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>>()
val writer =
SnowflakeWriter(
names = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(),
)

// Simulate network failure during namespace creation
coEvery {
snowflakeClient.createNamespace(tableName.namespace.toSnowflakeCompatibleName())
} throws RuntimeException("Network connection failed")

assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } }
}

@Test
fun testSetupWithInitialStatusGatheringFailure() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
val stream = mockk<DestinationStream>()
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer = mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>>()
val writer =
SnowflakeWriter(
names = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(),
)

// Simulate failure while gathering initial status
coEvery { stateGatherer.gatherInitialStatus(catalog) } throws
RuntimeException("Failed to query table status")

assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } }

// Verify namespace creation was still attempted
coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) }
}

@Test
fun testCreateStreamLoaderWithMissingInitialStatus() {
val tableName = TableName(namespace = "test-namespace", name = "test-name")
val tableNames = TableNames(rawTableName = null, finalTableName = tableName)
fun testCreateStreamLoaderNamespaceLegacy() {
val namespace = "test-namespace"
val name = "test-name"
val tableName = TableName(namespace = namespace, name = name)
val tempTableName = TableName(namespace = namespace, name = "${name}-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream =
mockk<DestinationStream> {
every { minimumGenerationId } returns 0L
every { generationId } returns 0L
every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
}
val missingStream =
mockk<DestinationStream> {
every { minimumGenerationId } returns 0L
every { generationId } returns 0L
}
val tableInfo =
TableNameInfo(
tableNames = tableNames,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream to tableInfo))
val catalog = DestinationCatalog(listOf(stream))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus(catalog) } returns
coEvery { gatherInitialStatus() } returns
mapOf(
stream to
DirectLoadInitialStatus(
realTable = DirectLoadTableStatus(false),
tempTable = null,
tempTable = null
)
)
}
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer =
SnowflakeWriter(
names = catalog,
catalog = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(relaxed = true),
snowflakeConfiguration =
mockk(relaxed = true) {
every { legacyRawTablesOnly } returns true
every { internalTableSchema } returns "internal_schema"
},
)

runBlocking {
writer.setup()
// Try to create loader for a stream that wasn't in initial status
assertThrows(NullPointerException::class.java) {
writer.createStreamLoader(missingStream)
runBlocking { writer.setup() }

coVerify(exactly = 1) { snowflakeClient.createNamespace(tableName.namespace) }
}

@Test
fun testCreateStreamLoaderNamespaceNonLegacy() {
val namespace = "test-namespace"
val name = "test-name"
val tableName = TableName(namespace = namespace, name = name)
val tempTableName = TableName(namespace = namespace, name = "${name}-temp")
val tableNames = TableNames(finalTableName = tableName, tempTableName = tempTableName)
val stream =
mockk<DestinationStream> {
every { minimumGenerationId } returns 0L
every { generationId } returns 0L
every { importType } returns Append
every { tableSchema } returns
StreamTableSchema(
tableNames = tableNames,
columnSchema =
ColumnSchema(
inputToFinalColumnNames = emptyMap(),
finalSchema = emptyMap(),
inputSchema = emptyMap()
),
importType = Append
)
every { mappedDescriptor } returns
DestinationStream.Descriptor(
namespace = tableName.namespace,
name = tableName.name
)
}
}
}

@Test
fun testCreateStreamLoaderWithNullFinalTableName() {
// TableNames constructor throws IllegalStateException when both names are null
assertThrows(IllegalStateException::class.java) {
TableNames(rawTableName = null, finalTableName = null)
}
}

@Test
fun testSetupWithMultipleNamespaceFailuresPartial() {
val tableName1 = TableName(namespace = "namespace1", name = "table1")
val tableName2 = TableName(namespace = "namespace2", name = "table2")
val tableNames1 = TableNames(rawTableName = null, finalTableName = tableName1)
val tableNames2 = TableNames(rawTableName = null, finalTableName = tableName2)
val stream1 = mockk<DestinationStream>()
val stream2 = mockk<DestinationStream>()
val tableInfo1 =
TableNameInfo(
tableNames = tableNames1,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val tableInfo2 =
TableNameInfo(
tableNames = tableNames2,
columnNameMapping = ColumnNameMapping(emptyMap())
)
val catalog = TableCatalog(mapOf(stream1 to tableInfo1, stream2 to tableInfo2))
val snowflakeClient = mockk<SnowflakeAirbyteClient>()
val stateGatherer = mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>>()
val catalog = DestinationCatalog(listOf(stream))
val snowflakeClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val stateGatherer =
mockk<DatabaseInitialStatusGatherer<DirectLoadInitialStatus>> {
coEvery { gatherInitialStatus() } returns
mapOf(
stream to
DirectLoadInitialStatus(
realTable = DirectLoadTableStatus(false),
tempTable = null
)
)
}
val streamStateStore = mockk<StreamStateStore<DirectLoadTableExecutionConfig>>()
val writer =
SnowflakeWriter(
names = catalog,
catalog = catalog,
stateGatherer = stateGatherer,
streamStateStore = mockk(),
streamStateStore = streamStateStore,
snowflakeClient = snowflakeClient,
tempTableNameGenerator = mockk(),
snowflakeConfiguration = mockk(),
snowflakeConfiguration =
mockk(relaxed = true) {
every { legacyRawTablesOnly } returns false
every { internalTableSchema } returns "internal_schema"
},
)

// First namespace succeeds, second fails (namespaces are uppercased by
// toSnowflakeCompatibleName)
coEvery { snowflakeClient.createNamespace("namespace1") } returns Unit
coEvery { snowflakeClient.createNamespace("namespace2") } throws
RuntimeException("Connection timeout")
runBlocking { writer.setup() }

assertThrows(RuntimeException::class.java) { runBlocking { writer.setup() } }

// Verify both namespace creations were attempted
coVerify(exactly = 1) { snowflakeClient.createNamespace("namespace1") }
coVerify(exactly = 1) { snowflakeClient.createNamespace("namespace2") }
coVerify(exactly = 1) { snowflakeClient.createNamespace(namespace) }
}
}
@@ -4,218 +4,220 @@

package io.airbyte.integrations.destination.snowflake.write.load

import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerValue
import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.table.TableName
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.cdk.load.schema.model.TableName
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.schema.SnowflakeColumnManager
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.mockk.coVerify
import io.mockk.every
import io.mockk.mockk
import java.io.BufferedReader
import java.io.File
import java.io.InputStreamReader
import java.util.zip.GZIPInputStream
import kotlin.io.path.exists
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNotNull
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test

internal class SnowflakeInsertBufferTest {

private lateinit var snowflakeConfiguration: SnowflakeConfiguration
private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
private lateinit var columnManager: SnowflakeColumnManager
private lateinit var columnSchema: ColumnSchema
private lateinit var snowflakeRecordFormatter: SnowflakeRecordFormatter

@BeforeEach
fun setUp() {
snowflakeConfiguration = mockk(relaxed = true)
snowflakeColumnUtils = mockk(relaxed = true)
snowflakeRecordFormatter = SnowflakeSchemaRecordFormatter()
columnManager =
mockk(relaxed = true) {
every { getMetaColumns() } returns
linkedMapOf(
"_AIRBYTE_RAW_ID" to ColumnType("VARCHAR", false),
"_AIRBYTE_EXTRACTED_AT" to ColumnType("TIMESTAMP_TZ", false),
"_AIRBYTE_META" to ColumnType("VARIANT", false),
"_AIRBYTE_GENERATION_ID" to ColumnType("NUMBER", true)
)
every { getTableColumnNames(any()) } returns
listOf(
"_AIRBYTE_RAW_ID",
"_AIRBYTE_EXTRACTED_AT",
"_AIRBYTE_META",
"_AIRBYTE_GENERATION_ID",
"columnName"
)
}
}

@Test
fun testAccumulate() {
val tableName = mockk<TableName>(relaxed = true)
val tableName = TableName(namespace = "test", name = "table")
val column = "columnName"
val columns = linkedMapOf(column to "NUMBER(38,0)")
columnSchema =
ColumnSchema(
inputToFinalColumnNames = mapOf(column to column.uppercase()),
finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
inputSchema = mapOf(column to FieldType(StringType, nullable = true))
)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord(column)
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = columns,
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
columnSchema = columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
)

buffer.accumulate(record)

assertEquals(true, buffer.csvFilePath?.exists())
assertEquals(0, buffer.recordCount)
runBlocking { buffer.accumulate(record) }
assertEquals(1, buffer.recordCount)
}

@Test
fun testAccumulateRaw() {
val tableName = mockk<TableName>(relaxed = true)
fun testFlushToStaging() {
val tableName = TableName(namespace = "test", name = "table")
val column = "columnName"
val columns = linkedMapOf(column to "NUMBER(38,0)")
columnSchema =
ColumnSchema(
inputToFinalColumnNames = mapOf(column to column.uppercase()),
finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
inputSchema = mapOf(column to FieldType(StringType, nullable = true))
)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord(column)
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = columns,
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
columnSchema = columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
)

every { snowflakeConfiguration.legacyRawTablesOnly } returns true

buffer.accumulate(record)

assertEquals(true, buffer.csvFilePath?.exists())
assertEquals(1, buffer.recordCount)
}

@Test
fun testFlush() {
val tableName = mockk<TableName>(relaxed = true)
val column = "columnName"
val columns = linkedMapOf(column to "NUMBER(38,0)")
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord(column)
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = columns,
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
val expectedColumnNames =
listOf(
"_AIRBYTE_RAW_ID",
"_AIRBYTE_EXTRACTED_AT",
"_AIRBYTE_META",
"_AIRBYTE_GENERATION_ID",
"columnName"
)

runBlocking {
buffer.accumulate(record)
buffer.flush()
}

coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
coVerify(exactly = 1) {
snowflakeAirbyteClient.copyFromStage(
tableName,
match { it.endsWith("$CSV_FILE_EXTENSION$FILE_SUFFIX") }
)
coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
coVerify(exactly = 1) {
snowflakeAirbyteClient.copyFromStage(tableName, any(), expectedColumnNames)
}
}
}

@Test
fun testFlushRaw() {
val tableName = mockk<TableName>(relaxed = true)
fun testFlushToNoStaging() {
val tableName = TableName(namespace = "test", name = "table")
val column = "columnName"
val columns = linkedMapOf(column to "NUMBER(38,0)")
columnSchema =
ColumnSchema(
inputToFinalColumnNames = mapOf(column to column.uppercase()),
finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
inputSchema = mapOf(column to FieldType(StringType, nullable = true))
)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
every { snowflakeConfiguration.legacyRawTablesOnly } returns true
val record = createRecord(column)
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
columnSchema = columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
flushLimit = 1,
)
val expectedColumnNames =
listOf(
"_AIRBYTE_RAW_ID",
"_AIRBYTE_EXTRACTED_AT",
"_AIRBYTE_META",
"_AIRBYTE_GENERATION_ID",
"columnName"
)
runBlocking {
buffer.accumulate(record)
buffer.flush()
// In legacy raw mode, it still uses staging
coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
coVerify(exactly = 1) {
snowflakeAirbyteClient.copyFromStage(tableName, any(), expectedColumnNames)
}
}
}

@Test
fun testFileCreation() {
val tableName = TableName(namespace = "test", name = "table")
val column = "columnName"
columnSchema =
ColumnSchema(
inputToFinalColumnNames = mapOf(column to column.uppercase()),
finalSchema = mapOf(column.uppercase() to ColumnType("NUMBER", true)),
inputSchema = mapOf(column to FieldType(StringType, nullable = true))
)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord(column)
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = columns,
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
columnSchema = columnSchema,
columnManager = columnManager,
snowflakeRecordFormatter = snowflakeRecordFormatter,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
)

every { snowflakeConfiguration.legacyRawTablesOnly } returns true

runBlocking {
buffer.accumulate(record)
buffer.flush()
}

coVerify(exactly = 1) { snowflakeAirbyteClient.putInStage(tableName, any()) }
coVerify(exactly = 1) {
snowflakeAirbyteClient.copyFromStage(
tableName,
match { it.endsWith("$CSV_FILE_EXTENSION$FILE_SUFFIX") }
)
}
}

@Test
fun testMissingFields() {
val tableName = mockk<TableName>(relaxed = true)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord("COLUMN1")
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = linkedMapOf("COLUMN1" to "NUMBER(38,0)", "COLUMN2" to "NUMBER(38,0)"),
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
)

runBlocking {
buffer.accumulate(record)
buffer.csvWriter?.flush()
// The csvFilePath is internal, we can access it for testing
val filepath = buffer.csvFilePath
assertNotNull(filepath)
val file = filepath!!.toFile()
assert(file.exists())
// Close the writer to ensure all data is flushed
buffer.csvWriter?.close()
assertEquals(
"test-value$CSV_FIELD_SEPARATOR$CSV_LINE_DELIMITER",
readFromCsvFile(buffer.csvFilePath!!.toFile())
)
val lines = mutableListOf<String>()
GZIPInputStream(file.inputStream()).use { gzip ->
BufferedReader(InputStreamReader(gzip)).use { bufferedReader ->
bufferedReader.forEachLine { line -> lines.add(line) }
}
}
assertEquals(1, lines.size)
file.delete()
}
}

@Test
fun testMissingFieldsRaw() {
val tableName = mockk<TableName>(relaxed = true)
val snowflakeAirbyteClient = mockk<SnowflakeAirbyteClient>(relaxed = true)
val record = createRecord("COLUMN1")
val buffer =
SnowflakeInsertBuffer(
tableName = tableName,
columns = linkedMapOf("COLUMN1" to "NUMBER(38,0)", "COLUMN2" to "NUMBER(38,0)"),
snowflakeClient = snowflakeAirbyteClient,
snowflakeConfiguration = snowflakeConfiguration,
flushLimit = 1,
snowflakeColumnUtils = snowflakeColumnUtils,
)

every { snowflakeConfiguration.legacyRawTablesOnly } returns true

runBlocking {
buffer.accumulate(record)
buffer.csvWriter?.flush()
buffer.csvWriter?.close()
assertEquals(
"test-value$CSV_FIELD_SEPARATOR$CSV_LINE_DELIMITER",
readFromCsvFile(buffer.csvFilePath!!.toFile())
)
}
}

private fun readFromCsvFile(file: File) =
GZIPInputStream(file.inputStream()).use { input ->
val reader = BufferedReader(InputStreamReader(input))
reader.readText()
}

private fun createRecord(columnName: String) =
mapOf(
columnName to AirbyteValue.from("test-value"),
Meta.COLUMN_NAME_AB_EXTRACTED_AT to IntegerValue(System.currentTimeMillis()),
Meta.COLUMN_NAME_AB_RAW_ID to StringValue("raw-id"),
Meta.COLUMN_NAME_AB_GENERATION_ID to IntegerValue(1223),
Meta.COLUMN_NAME_AB_META to StringValue("{\"changes\":[],\"syncId\":43}"),
"${columnName}Null" to NullValue
private fun createRecord(column: String): Map<String, AirbyteValue> {
return mapOf(
column to IntegerValue(value = 42),
Meta.COLUMN_NAME_AB_GENERATION_ID to NullValue,
Meta.COLUMN_NAME_AB_RAW_ID to StringValue("raw-id-1"),
Meta.COLUMN_NAME_AB_EXTRACTED_AT to IntegerValue(1234567890),
Meta.COLUMN_NAME_AB_META to StringValue("meta-data-foo"),
)
}
}
@@ -15,12 +15,8 @@ import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_LOADED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_DATA
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.mockk.every
import io.mockk.mockk
import kotlin.collections.plus
import io.airbyte.cdk.load.schema.model.ColumnSchema
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test

private val AIRBYTE_COLUMN_TYPES_MAP =
@@ -58,28 +54,16 @@ private fun createExpected(

internal class SnowflakeRawRecordFormatterTest {

private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils

@BeforeEach
fun setup() {
snowflakeColumnUtils = mockk {
every { getFormattedDefaultColumnNames(any()) } returns
AIRBYTE_COLUMN_TYPES_MAP.keys.toList()
}
}

@Test
fun testFormatting() {
val columnName = "test-column-name"
val columnValue = "test-column-value"
val columns = AIRBYTE_COLUMN_TYPES_MAP
val record = createRecord(columnName = columnName, columnValue = columnValue)
val formatter =
SnowflakeRawRecordFormatter(
columns = AIRBYTE_COLUMN_TYPES_MAP,
snowflakeColumnUtils = snowflakeColumnUtils
)
val formattedValue = formatter.format(record)
val formatter = SnowflakeRawRecordFormatter()
// RawRecordFormatter doesn't use columnSchema but still needs one per interface
val dummyColumnSchema = ColumnSchema(emptyMap(), emptyMap(), emptyMap())
val formattedValue = formatter.format(record, dummyColumnSchema)
val expectedValue =
createExpected(
record = record,
@@ -93,33 +77,28 @@ internal class SnowflakeRawRecordFormatterTest {
fun testFormattingMigratedFromPreviousVersion() {
val columnName = "test-column-name"
val columnValue = "test-column-value"
val columnsMap =
linkedMapOf(
COLUMN_NAME_AB_EXTRACTED_AT to "TIMESTAMP_TZ(9)",
COLUMN_NAME_AB_LOADED_AT to "TIMESTAMP_TZ(9)",
COLUMN_NAME_AB_META to "VARIANT",
COLUMN_NAME_DATA to "VARIANT",
COLUMN_NAME_AB_RAW_ID to "VARCHAR(16777216)",
COLUMN_NAME_AB_GENERATION_ID to "NUMBER(38,0)",
)
val record = createRecord(columnName = columnName, columnValue = columnValue)
val formatter =
SnowflakeRawRecordFormatter(
columns = columnsMap,
snowflakeColumnUtils = snowflakeColumnUtils
)
val formattedValue = formatter.format(record)
val formatter = SnowflakeRawRecordFormatter()
// RawRecordFormatter doesn't use columnSchema but still needs one per interface
val dummyColumnSchema = ColumnSchema(emptyMap(), emptyMap(), emptyMap())
val formattedValue = formatter.format(record, dummyColumnSchema)

// The formatter outputs in a fixed order regardless of input column order:
// 1. AB_RAW_ID
// 2. AB_EXTRACTED_AT
// 3. AB_META
// 4. AB_GENERATION_ID
// 5. AB_LOADED_AT
// 6. DATA (JSON with remaining columns)
val expectedValue =
createExpected(
record = record,
columns = columnsMap,
airbyteColumns = columnsMap.keys.toList(),
)
.toMutableList()
expectedValue.add(
columnsMap.keys.indexOf(COLUMN_NAME_DATA),
"{\"$columnName\":\"$columnValue\"}"
)
listOf(
record[COLUMN_NAME_AB_RAW_ID]!!.toCsvValue(),
record[COLUMN_NAME_AB_EXTRACTED_AT]!!.toCsvValue(),
record[COLUMN_NAME_AB_META]!!.toCsvValue(),
record[COLUMN_NAME_AB_GENERATION_ID]!!.toCsvValue(),
record[COLUMN_NAME_AB_LOADED_AT]!!.toCsvValue(),
"{\"$columnName\":\"$columnValue\"}"
)
assertEquals(expectedValue, formattedValue)
}
}
@@ -4,66 +4,58 @@

package io.airbyte.integrations.destination.snowflake.write.load

import io.airbyte.cdk.load.component.ColumnType
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerValue
import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.data.csv.toCsvValue
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_EXTRACTED_AT
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_META
import io.airbyte.cdk.load.message.Meta.Companion.COLUMN_NAME_AB_RAW_ID
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.mockk.every
import io.mockk.mockk
import java.util.AbstractMap
import kotlin.collections.component1
import kotlin.collections.component2
import io.airbyte.cdk.load.schema.model.ColumnSchema
import io.airbyte.integrations.destination.snowflake.schema.toSnowflakeCompatibleName
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test

private val AIRBYTE_COLUMN_TYPES_MAP =
linkedMapOf(
COLUMN_NAME_AB_RAW_ID to "VARCHAR(16777216)",
COLUMN_NAME_AB_EXTRACTED_AT to "TIMESTAMP_TZ(9)",
COLUMN_NAME_AB_META to "VARIANT",
COLUMN_NAME_AB_GENERATION_ID to "NUMBER(38,0)",
)
.mapKeys { it.key.toSnowflakeCompatibleName() }

internal class SnowflakeSchemaRecordFormatterTest {

private lateinit var snowflakeColumnUtils: SnowflakeColumnUtils
private fun createColumnSchema(userColumns: Map<String, String>): ColumnSchema {
val finalSchema = linkedMapOf<String, ColumnType>()
val inputToFinalColumnNames = mutableMapOf<String, String>()
val inputSchema = mutableMapOf<String, FieldType>()

@BeforeEach
fun setup() {
snowflakeColumnUtils = mockk {
every { getFormattedDefaultColumnNames(any()) } returns
AIRBYTE_COLUMN_TYPES_MAP.keys.toList()
// Add user columns
userColumns.forEach { (name, type) ->
val finalName = name.toSnowflakeCompatibleName()
finalSchema[finalName] = ColumnType(type, true)
inputToFinalColumnNames[name] = finalName
inputSchema[name] = FieldType(StringType, nullable = true)
}

return ColumnSchema(
inputToFinalColumnNames = inputToFinalColumnNames,
finalSchema = finalSchema,
inputSchema = inputSchema
)
}

@Test
fun testFormatting() {
val columnName = "test-column-name"
val columnValue = "test-column-value"
val columns =
(AIRBYTE_COLUMN_TYPES_MAP + linkedMapOf(columnName to "VARCHAR(16777216)")).mapKeys {
it.key.toSnowflakeCompatibleName()
}
val userColumns = mapOf(columnName to "VARCHAR(16777216)")
val columnSchema = createColumnSchema(userColumns)
val record = createRecord(columnName, columnValue)
val formatter =
SnowflakeSchemaRecordFormatter(
columns = columns as LinkedHashMap<String, String>,
snowflakeColumnUtils = snowflakeColumnUtils
)
val formattedValue = formatter.format(record)
val formatter = SnowflakeSchemaRecordFormatter()
val formattedValue = formatter.format(record, columnSchema)
val expectedValue =
createExpected(
record = record,
columns = columns,
columnSchema = columnSchema,
)
assertEquals(expectedValue, formattedValue)
}
@@ -72,21 +64,15 @@ internal class SnowflakeSchemaRecordFormatterTest {
fun testFormattingVariant() {
val columnName = "test-column-name"
val columnValue = "{\"test\": \"test-value\"}"
val columns =
(AIRBYTE_COLUMN_TYPES_MAP + linkedMapOf(columnName to "VARIANT")).mapKeys {
it.key.toSnowflakeCompatibleName()
}
val userColumns = mapOf(columnName to "VARIANT")
val columnSchema = createColumnSchema(userColumns)
val record = createRecord(columnName, columnValue)
val formatter =
SnowflakeSchemaRecordFormatter(
columns = columns as LinkedHashMap<String, String>,
snowflakeColumnUtils = snowflakeColumnUtils
)
val formattedValue = formatter.format(record)
val formatter = SnowflakeSchemaRecordFormatter()
val formattedValue = formatter.format(record, columnSchema)
val expectedValue =
createExpected(
record = record,
columns = columns,
columnSchema = columnSchema,
)
assertEquals(expectedValue, formattedValue)
}
@@ -95,23 +81,16 @@ internal class SnowflakeSchemaRecordFormatterTest {
fun testFormattingMissingColumn() {
val columnName = "test-column-name"
val columnValue = "test-column-value"
val columns =
AIRBYTE_COLUMN_TYPES_MAP +
linkedMapOf(
columnName to "VARCHAR(16777216)",
"missing-column" to "VARCHAR(16777216)"
)
val userColumns =
mapOf(columnName to "VARCHAR(16777216)", "missing-column" to "VARCHAR(16777216)")
val columnSchema = createColumnSchema(userColumns)
val record = createRecord(columnName, columnValue)
val formatter =
SnowflakeSchemaRecordFormatter(
columns = columns as LinkedHashMap<String, String>,
snowflakeColumnUtils = snowflakeColumnUtils
)
val formattedValue = formatter.format(record)
val formatter = SnowflakeSchemaRecordFormatter()
val formattedValue = formatter.format(record, columnSchema)
val expectedValue =
createExpected(
record = record,
columns = columns,
columnSchema = columnSchema,
filterMissing = false,
)
assertEquals(expectedValue, formattedValue)
@@ -128,16 +107,37 @@ internal class SnowflakeSchemaRecordFormatterTest {

private fun createExpected(
record: Map<String, AirbyteValue>,
columns: Map<String, String>,
columnSchema: ColumnSchema,
filterMissing: Boolean = true,
) =
record.entries
.associate { entry -> entry.key.toSnowflakeCompatibleName() to entry.value }
.map { entry -> AbstractMap.SimpleEntry(entry.key, entry.value.toCsvValue()) }
.sortedBy { entry ->
if (columns.keys.indexOf(entry.key) > -1) columns.keys.indexOf(entry.key)
else Int.MAX_VALUE
): List<Any> {
val columns = columnSchema.finalSchema.keys.toList()
val result = mutableListOf<Any>()

// Add meta columns first in the expected order
result.add(record[COLUMN_NAME_AB_RAW_ID]?.toCsvValue() ?: "")
result.add(record[COLUMN_NAME_AB_EXTRACTED_AT]?.toCsvValue() ?: "")
result.add(record[COLUMN_NAME_AB_META]?.toCsvValue() ?: "")
result.add(record[COLUMN_NAME_AB_GENERATION_ID]?.toCsvValue() ?: "")

// Add user columns
val userColumns =
columns.filterNot { col ->
listOf(
COLUMN_NAME_AB_RAW_ID.toSnowflakeCompatibleName(),
COLUMN_NAME_AB_EXTRACTED_AT.toSnowflakeCompatibleName(),
COLUMN_NAME_AB_META.toSnowflakeCompatibleName(),
COLUMN_NAME_AB_GENERATION_ID.toSnowflakeCompatibleName()
)
.contains(col)
}
.filter { (k, _) -> if (filterMissing) columns.contains(k) else true }
.map { it.value }

userColumns.forEach { columnName ->
val value = record[columnName] ?: if (!filterMissing) NullValue else null
if (value != null || !filterMissing) {
result.add(value?.toCsvValue() ?: "")
}
}

return result
}
}
Some files were not shown because too many files have changed in this diff.